-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
152 lines (114 loc) · 3.92 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import random
from typing import List, Union
import numpy as np
import torch
from torch import Tensor
def pairwise_cosine_similarity(x: Tensor, y: Tensor, zero_diagonal: bool = False) -> Tensor:
    r"""
    Calculates the pairwise cosine similarity matrix

    Args:
        x: tensor of shape ``(batch_size, M, d)``
        y: tensor of shape ``(batch_size, N, d)``
        zero_diagonal: determines if the diagonal of the similarity matrix should be set to zero

    Returns:
        A tensor of shape ``(batch_size, M, N)``

    Raises:
        ValueError: if ``zero_diagonal`` is requested but M != N (no square diagonal exists)
    """
    x_norm = torch.linalg.norm(x, dim=2, keepdim=True)
    y_norm = torch.linalg.norm(y, dim=2, keepdim=True)
    # Clamp the norms away from zero so all-zero rows yield a similarity of 0
    # instead of NaN/inf from a division by zero.
    similarity = torch.matmul(
        x / x_norm.clamp_min(1e-8),
        (y / y_norm.clamp_min(1e-8)).permute(0, 2, 1),
    )
    if zero_diagonal:
        if x.shape[1] != y.shape[1]:
            raise ValueError('zero_diagonal requires x and y to have the same number of rows')
        # Build the boolean diagonal mask directly on the result's device.
        mask = torch.eye(x.shape[1], dtype=torch.bool, device=similarity.device)
        similarity.masked_fill_(mask.unsqueeze(0).expand(x.shape[0], -1, -1), 0)
    return similarity
def load_embed(path: str):
    r"""
    Load a pre-trained embedding from ``path``.

    NOTE(review): currently an unimplemented stub — it ignores ``path`` and
    always returns ``None``. Callers must handle a ``None`` result.

    Args:
        path: location of the embedding file.

    Returns:
        ``None`` (placeholder).
    """
    return None
def get_device() -> str:
    r"""
    Return the device available for execution

    Returns:
        ``cpu`` for CPU or ``cuda`` for GPU
    """
    return 'cuda' if torch.cuda.is_available() else 'cpu'
def to_device(batch: dict, device: object) -> dict:
    r"""
    Convert a batch to the specified device

    Args:
        batch: the batch needs to be converted.
        device: GPU or CPU.

    Returns:
        A batch after converting
    """
    # Every value in the batch is assumed to expose ``.to`` (tensor-like).
    return {name: value.to(device) for name, value in batch.items()}
def convert_arg_line_to_args(arg_line):
    r"""
    Convert a line of arguments into individual arguments

    Args:
        arg_line: a string read from the argument file.

    Returns:
        A list of arguments parsed from ``arg_line``; empty for blank lines
        and ``#`` comment lines.
    """
    # The original mixed ``return []`` with ``yield``, which made the function
    # a generator whose ``return`` value was silently discarded, while the
    # docstring promised a list. Return an actual list in every branch.
    arg_line = arg_line.strip()
    if arg_line.startswith('#') or not arg_line:
        return []
    # str.split() with no separator never produces empty strings, so no
    # per-token filtering is needed.
    return arg_line.split()
def set_seed(seed: int):
    r"""
    Sets the seed for generating random numbers

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and every visible GPU)
    so runs are reproducible.

    Args:
        seed: seed value.

    Returns:
        None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU, not just the current one, so
    # multi-GPU runs are reproducible too. No-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
def padded_stack(tensors: Union[List[Tensor], List[List]], padding: int = 0):
    r"""
    Pad a list of variable length Tensors with ``padding``

    Args:
        tensors: list of variable length sequences.
        padding: value for padded elements. Default: 0.

    Returns:
        Padded sequences stacked along a new leading dimension.
    """
    # isinstance (not ``type(...) ==``) so list subclasses are handled too.
    if isinstance(tensors[0], list):
        tensors = [torch.tensor(tensor) for tensor in tensors]
    n_dim = tensors[0].dim()
    # Per-dimension maximum across all tensors; every input is assumed to
    # share the same rank.
    max_shape = [max(tensor.shape[d] for tensor in tensors) for d in range(n_dim)]
    return torch.stack([expand_tensor(tensor, max_shape, fill=padding) for tensor in tensors])
def expand_tensor(tensor: Tensor, extended_shape: List[int], fill: int = 0):
    r"""
    Expand a tensor to ``extended_shape``

    Works for any rank (the previous version only handled ranks 1-4 and
    silently returned an all-``fill`` tensor otherwise).

    Args:
        tensor: tensor to expand.
        extended_shape: new shape; each entry must be >= the matching
            dimension of ``tensor``.
        fill: value for padded elements. Default: 0.

    Returns:
        An expanded tensor containing ``tensor`` in its leading corner and
        ``fill`` everywhere else.
    """
    # Single allocation with the fill value, on the source tensor's device.
    expanded_tensor = torch.full(extended_shape, fill, dtype=tensor.dtype, device=tensor.device)
    # A slice per dimension selects the leading corner for any rank.
    region = tuple(slice(0, size) for size in tensor.shape)
    expanded_tensor[region] = tensor
    return expanded_tensor