/
train.py
46 lines (39 loc) · 2.07 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import numpy as np
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# (superseded by torch.set_default_dtype / set_default_device), and this call
# makes every freshly constructed tensor a CUDA float tensor — the script
# therefore requires a CUDA-enabled build at import time. TODO confirm this
# global default is intentional rather than per-tensor .to(device) placement.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
def train(itr, dataset, args, model, optimizer, logger, device,writer):
    """Run one training iteration and return the scalar total loss.

    Loads one batch of similar-pair features, trims the time axis to the
    longest non-padding sequence in the batch, runs the forward pass and
    criterion, logs each loss component to TensorBoard, and takes a single
    optimizer step.

    Args:
        itr: Current iteration index (forwarded to the model/criterion and
            used as the logging step).
        dataset: Object exposing ``load_data(n_similar, similar_size)`` that
            returns ``(features, labels, pairs_id)`` as numpy arrays.
        args: Options namespace; ``num_similar`` and ``similar_size`` are
            read here, the rest is forwarded as ``opt``.
        model: Network exposing ``__call__`` and ``criterion``.
        optimizer: Torch optimizer over the model's parameters.
        logger: Forwarded to ``model.criterion`` (may be None).
        device: Torch device the batch tensors are moved to.
        writer: TensorBoard ``SummaryWriter`` (anything with ``add_scalar``).

    Returns:
        The total loss for this iteration as a numpy scalar.
    """
    model.train()
    features, labels, pairs_id = dataset.load_data(
        n_similar=args.num_similar, similar_size=args.similar_size)

    # A time step is "real" iff any feature channel is non-zero; trim the
    # batch to the longest real sequence so padding is never processed.
    seq_len = np.sum(np.max(np.abs(features), axis=2) > 0, axis=1)
    features = features[:, :np.max(seq_len), :]

    features = torch.from_numpy(features).float().to(device)
    labels = torch.from_numpy(labels).float().to(device)

    outputs = model(features, seq_len=seq_len, is_training=True, itr=itr,
                    opt=args, target=labels)
    total_loss, loss_dict = model.criterion(
        outputs, labels, seq_len=seq_len, device=device, logger=logger,
        opt=args, itr=itr, pairs_id=pairs_id, inputs=features)

    # Log every component the criterion reports. Iterating the dict (instead
    # of hard-coding the five expected keys) avoids a KeyError when a
    # configuration omits a component, and picks up any new ones for free.
    for tag, value in loss_dict.items():
        writer.add_scalar(tag, value, itr)

    optimizer.zero_grad()
    total_loss.backward()
    # Gradient clipping intentionally disabled:
    # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    optimizer.step()

    # .detach() replaces the deprecated .data access; same numpy scalar out.
    return total_loss.detach().cpu().numpy()