import re
import torch
import numpy as np
from torch import nn
import pandas as pd
from torch.utils.data import DataLoader
from torch.nn import functional as F
from nltk.corpus import stopwords

# Requires the NLTK stop-word corpus: nltk.download('stopwords')
stop_words = set(stopwords.words('english'))


class QuantileLoss(nn.Module):
    ## From: https://medium.com/the-artificial-impostor/quantile-regression-part-2-6fdbc26b2629
    def __init__(self, quantiles):
        ## takes a list of quantiles
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        losses = []
        for i, q in enumerate(self.quantiles):
            errors = target - preds[:, i]
            # Pinball loss: (q - 1) * error for over-prediction, q * error for under-prediction
            losses.append(
                torch.max(
                    (q - 1) * errors,
                    q * errors
                ).unsqueeze(1))
        loss = torch.mean(
            torch.sum(torch.cat(losses, dim=1), dim=1))
        return loss
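

# Illustrative usage sketch (not part of the original module): a minimal,
# self-contained check of QuantileLoss on random data, assuming the network
# emits one output column per requested quantile.
def _demo_quantile_loss():
    torch.manual_seed(0)
    quantiles = [0.1, 0.5, 0.9]
    criterion = QuantileLoss(quantiles)
    preds = torch.randn(8, len(quantiles))   # (batch, n_quantiles)
    target = torch.randn(8)                  # (batch,)
    print('pinball loss:', criterion(preds, target).item())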


def training_step(model, batch, device):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out, *_ = model(images)          # Generate predictions
    loss = F.l1_loss(out, labels)    # Calculate loss
    return loss


def validation_step(model, batch, device):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out, *_ = model(images)          # Generate predictions
    loss = F.l1_loss(out, labels)    # Calculate loss
    return {'Loss': loss.detach()}


def validation_epoch_end(model, outputs):
    batch_losses = [x['Loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
    return {'Loss': epoch_loss.item()}


def epoch_end(model, epoch, result):
    print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}".format(
        epoch, result['lrs'][-1], result['train_loss'], result['Loss']))


@torch.no_grad()
def evaluate(model, val_loader, device):
    model.eval()
    outputs = [validation_step(model, batch, device) for batch in val_loader]
    return validation_epoch_end(model, outputs)


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


def fit_one_cycle(epochs, model, train_loader, val_loader, device, save_path, lr=0.01):
    best_loss = np.inf
    torch.cuda.empty_cache()
    history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.1, patience=10, verbose=True)
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch, device)
            train_losses.append(loss.detach())   # detach so the graph is not kept for the whole epoch
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lrs.append(get_lr(optimizer))
        # Validation phase
        result = evaluate(model, val_loader, device)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        sched.step(result['Loss'])
        # Keep a checkpoint of the best model seen so far
        if best_loss > result['Loss']:
            best_loss = result['Loss']
            torch.save(model.state_dict(), save_path)
    return history
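

# Illustrative usage sketch (not part of the original module): wiring
# fit_one_cycle up to a toy regression model and in-memory DataLoaders.
# The model, data and 'toy_model.pt' path are placeholders, not the project's
# real training setup.
def _demo_fit_one_cycle(save_path='toy_model.pt'):
    from torch.utils.data import TensorDataset

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 1)

        def forward(self, x):
            # Return a tuple so that `out, *_ = model(images)` unpacks cleanly.
            return (self.fc(x),)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x, y = torch.randn(64, 4), torch.randn(64, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
    val_loader = DataLoader(TensorDataset(x, y), batch_size=16)
    return fit_one_cycle(2, _ToyModel().to(device), train_loader, val_loader,
                         device, save_path, lr=0.01)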


def inference_step(model, batch, device):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    out, *_ = model(images)   # Generate predictions
    return out


@torch.no_grad()
def predict(model, val_loader, device):
    model.eval()
    outputs = [inference_step(model, batch, device) for batch in val_loader]
    return torch.cat(outputs, dim=0)
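

# Illustrative usage sketch (not part of the original module): batched
# inference with predict() on a toy model that, like the sketch above, returns
# a tuple so that `out, *_ = model(images)` unpacks cleanly.
def _demo_predict():
    from torch.utils.data import TensorDataset

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 1)

        def forward(self, x):
            return (self.fc(x),)

    device = torch.device('cpu')
    loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.zeros(32, 1)),
                        batch_size=8)
    preds = predict(_ToyModel().to(device), loader, device)
    print(preds.shape)   # torch.Size([32, 1])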


def clean_text(text):
    # lower case characters only
    text = text.lower()
    # remove urls
    text = re.sub(r'http\S+', ' ', text)
    # only alphabets, spaces and apostrophes
    text = re.sub(r"[^a-z' ]+", ' ', text)
    # remove all apostrophes which are not used in word contractions
    text = ' ' + text + ' '
    text = re.sub(r"[^a-z]'|'[^a-z]", ' ', text)
    # drop stop words (text is already lower-cased)
    split_sentence = text.split()
    filtered_sentence = [w for w in split_sentence if w not in stop_words]
    return filtered_sentence
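

# Illustrative usage sketch (not part of the original module): what clean_text
# returns for a short sample string. URLs, digits/punctuation and NLTK stop
# words are stripped; apostrophes inside contractions survive.
def _demo_clean_text():
    sample = "Check out https://example.com before it's gone!"
    print(clean_text(sample))   # e.g. ['check', 'gone'] with the default stop-word list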