From 366e580e2dee55894817825427b379f79bc43c8b Mon Sep 17 00:00:00 2001 From: divyat09 Date: Thu, 20 Aug 2020 13:35:15 +0000 Subject: [PATCH 1/3] DRO, ERM code added; validation acc added to train procedure --- algorithms/algo.py | 11 +++-- algorithms/dro.py | 100 ++++++++++++++++++++++++++++++++++++++++ algorithms/erm.py | 82 ++++++++++++++++++++++++++++++++ algorithms/erm_match.py | 9 ++-- algorithms/irm.py | 9 ++-- algorithms/match_dg.py | 9 ++-- evaluation/base_eval.py | 4 +- train.py | 26 +++++++++-- 8 files changed, 231 insertions(+), 19 deletions(-) create mode 100644 algorithms/dro.py create mode 100644 algorithms/erm.py diff --git a/algorithms/algo.py b/algorithms/algo.py index 1c2e8c3..1acfa76 100644 --- a/algorithms/algo.py +++ b/algorithms/algo.py @@ -17,9 +17,10 @@ from utils.match_function import get_matched_pairs class BaseAlgo(): - def __init__(self, args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, run, cuda): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, run, cuda): self.args= args self.train_dataset= train_dataset + self.val_dataset= val_dataset self.test_dataset= test_dataset self.train_domains= train_domains self.total_domains= total_domains @@ -90,11 +91,15 @@ def get_match_function(self, epoch): return data_match_tensor, label_match_tensor - def get_test_accuracy(self): + def get_test_accuracy(self, case): #Test Env Code test_acc= 0.0 test_size=0 + if case == 'val': + dataset= self.val_dataset + elif case == 'test': + dataset= self.test_dataset for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.test_dataset): with torch.no_grad(): @@ -108,6 +113,6 @@ def get_test_accuracy(self): test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item() test_size+= y_e.shape[0] - print(' Accuracy: ', 100*test_acc/test_size ) + print(' Accuracy: ', case, 100*test_acc/test_size ) return 100*test_acc/test_size \ No newline at end of file diff --git a/algorithms/dro.py b/algorithms/dro.py new file mode 100644 index 0000000..5e9aef9 --- /dev/null +++ b/algorithms/dro.py @@ -0,0 +1,100 @@ +import sys +import numpy as np +import argparse +import copy +import random +import json + +import torch +from torch.autograd import grad +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch.autograd import Variable +import torch.utils.data as data_utils + +from .algo import BaseAlgo +from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity + +class DRO(BaseAlgo): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): + + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + + def train(self): + + for epoch in range(self.args.epochs): + + if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag): + data_match_tensor, label_match_tensor= self.get_match_function(epoch) + + penalty_erm=0 + train_acc= 0.0 + train_size=0 + + perm = torch.randperm(data_match_tensor.size(0)) + data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0) + label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0) + print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split)) + + #Batch iteration over single epoch + for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset): + # print('Batch Idx: ', batch_idx) + + self.opt.zero_grad() + loss_e= torch.tensor(0.0).to(self.cuda) + + x_e= x_e.to(self.cuda) + y_e= torch.argmax(y_e, dim=1).to(self.cuda) + d_e= torch.argmax(d_e, dim=1).numpy() + + #Forward Pass + out= self.phi(x_e) + + erm_loss= torch.tensor(0.0).to(self.cuda) + if epoch > self.args.penalty_s: + # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split + total_batch_size= len(data_match_tensor_split) + if batch_idx >= total_batch_size: + break + curr_batch_size= data_match_tensor_split[batch_idx].shape[0] + + data_match= data_match_tensor_split[batch_idx].to(self.cuda) + label_match= label_match_tensor_split[batch_idx].to(self.cuda) + + for domain_idx in range(data_match.shape[1]): + + data_idx= data_match[:,domain_idx,:,:,:] + feat_idx= self.phi( data_idx ) + + label_idx= label_match[:, domain_idx] + label_idx= label_idx.view(label_idx.shape[0]) + erm_loss = torch.max(erm_loss, F.cross_entropy(feat_idx, label_idx.long()).to(self.cuda)) + + penalty_erm+= float(erm_loss) + loss_e += erm_loss + + loss_e.backward(retain_graph=False) + self.opt.step() + + del erm_loss + del loss_e + torch.cuda.empty_cache() + + train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item() + train_size+= y_e.shape[0] + + + print('Train Loss Basic : ', penalty_erm ) + print('Train Acc Env : ', 100*train_acc/train_size ) + print('Done Training for epoch: ', epoch) + + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + + #Test Dataset Accuracy + self.final_acc.append( self.get_test_accuracy('test') ) + + # Save the model's weights post training + self.save_model() \ No newline at end of file diff --git a/algorithms/erm.py b/algorithms/erm.py new file mode 100644 index 0000000..1034a27 --- /dev/null +++ b/algorithms/erm.py @@ -0,0 +1,82 @@ +import sys +import numpy as np +import argparse +import copy +import random +import json + +import torch +from torch.autograd import grad +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch.autograd import Variable +import torch.utils.data as data_utils + +from .algo import BaseAlgo +from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity + +class Erm(BaseAlgo): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): + + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + + def train(self): + + for epoch in range(self.args.epochs): + + if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag): + data_match_tensor, label_match_tensor= self.get_match_function(epoch) + + penalty_erm=0 + penalty_ws=0 + train_acc= 0.0 + train_size=0 + + perm = torch.randperm(data_match_tensor.size(0)) + data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0) + label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0) + print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split)) + + #Batch iteration over single epoch + for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset): + # print('Batch Idx: ', batch_idx) + + self.opt.zero_grad() + loss_e= torch.tensor(0.0).to(self.cuda) + + x_e= x_e.to(self.cuda) + y_e= torch.argmax(y_e, dim=1).to(self.cuda) + d_e= torch.argmax(d_e, dim=1).numpy() + + #Forward Pass + out= self.phi(x_e) + erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda) + loss_e+= erm_loss + penalty_erm += float(loss_e) + + #Backprorp + loss_e.backward(retain_graph=False) + self.opt.step() + + del erm_loss + del loss_e + torch.cuda.empty_cache() + + train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item() + train_size+= y_e.shape[0] + + + print('Train Loss Basic : ', penalty_erm ) + print('Train Acc Env : ', 100*train_acc/train_size ) + print('Done Training for epoch: ', epoch) + + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + + #Test Dataset Accuracy + self.final_acc.append( self.get_test_accuracy('test') ) + + # Save the model's weights post training + self.save_model() \ No newline at end of file diff --git a/algorithms/erm_match.py b/algorithms/erm_match.py index e80a2fe..fd8e7f9 100644 --- a/algorithms/erm_match.py +++ b/algorithms/erm_match.py @@ -18,9 +18,9 @@ from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity class ErmMatch(BaseAlgo): - def __init__(self, args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): - super().__init__(args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) def train(self): @@ -127,8 +127,11 @@ def train(self): print('Train Acc Env : ', 100*train_acc/train_size ) print('Done Training for epoch: ', epoch) + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + #Test Dataset Accuracy - self.final_acc.append( self.get_test_accuracy() ) + self.final_acc.append( self.get_test_accuracy('test') ) # Save the model's weights post training self.save_model() diff --git a/algorithms/irm.py b/algorithms/irm.py index 316ec7c..a37e7dd 100644 --- a/algorithms/irm.py +++ b/algorithms/irm.py @@ -18,9 +18,9 @@ from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, compute_irm_penalty class Irm(BaseAlgo): - def __init__(self, args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): - super().__init__(args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) def train(self): @@ -114,8 +114,11 @@ def train(self): print('Train Acc Env : ', 100*train_acc/train_size ) print('Done Training for epoch: ', epoch) + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + #Test Dataset Accuracy - self.final_acc.append( self.get_test_accuracy() ) + self.final_acc.append( self.get_test_accuracy('test') ) # Save the model's weights post training self.save_model() diff --git a/algorithms/match_dg.py b/algorithms/match_dg.py index 311db87..53daec3 100644 --- a/algorithms/match_dg.py +++ b/algorithms/match_dg.py @@ -20,9 +20,9 @@ from utils.match_function import get_matched_pairs class MatchDG(BaseAlgo): - def __init__(self, args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda, ctr_phase=1): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda, ctr_phase=1): - super().__init__(args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) self.ctr_phase= ctr_phase self.ctr_save_post_string= str(self.args.match_case) + '_' + str(self.args.match_interrupt) + '_' + str(self.args.match_flag) + '_' + str(self.run) + '_' + self.args.model_name @@ -327,8 +327,11 @@ def train_erm_phase(self): print('Train Acc Env : ', 100*train_acc/train_size ) print('Done Training for epoch: ', epoch) + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + #Test Dataset Accuracy - self.final_acc.append( self.get_test_accuracy() ) + self.final_acc.append( self.get_test_accuracy('test') ) # Save the model's weights post training self.save_model_erm_phase(run_erm) diff --git a/evaluation/base_eval.py b/evaluation/base_eval.py index 01e8b6c..f2abe48 100644 --- a/evaluation/base_eval.py +++ b/evaluation/base_eval.py @@ -59,7 +59,7 @@ def __init__(self, args, train_dataset, test_dataset, train_domains, self.args.ctr_model_name ) - if self.args.method_name == 'erm_match': + if self.args.method_name in ['erm_match', 'erm', 'irm', 'dro']: self.save_path= self.base_res_dir + '/Model_' + self.post_string elif self.args.method_name == 'matchdg_ctr': @@ -71,8 +71,6 @@ def __init__(self, args, train_dataset, test_dataset, train_domains, self.ctr_load_post_string + '/Model_' + self.post_string + '_' + str(run) ) - elif self.args.method_name == 'irm_match': - self.save_path= self.base_res_dir + '/Model_' + self.post_string self.phi= self.get_model() self.load_model() diff --git a/train.py b/train.py index 06db065..870eaad 100644 --- a/train.py +++ b/train.py @@ -137,7 +137,7 @@ if args.method_name == 'erm_match': from algorithms.erm_match import ErmMatch train_method= ErmMatch( - args, train_dataset, + args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, @@ -147,7 +147,7 @@ from algorithms.match_dg import MatchDG ctr_phase=1 train_method= MatchDG( - args, train_dataset, + args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, @@ -157,16 +157,34 @@ from algorithms.match_dg import MatchDG ctr_phase=0 train_method= MatchDG( - args, train_dataset, + args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, run, cuda, ctr_phase ) + elif args.method_name == 'erm': + from algorithms.erm import Erm + train_method= Erm( + args, train_dataset, val_dataset, + test_dataset, train_domains, + total_domains, domain_size, + training_list_size, base_res_dir, + run, cuda + ) elif args.method_name == 'irm': from algorithms.irm import Irm train_method= Irm( - args, train_dataset, + args, train_dataset, val_dataset, + test_dataset, train_domains, + total_domains, domain_size, + training_list_size, base_res_dir, + run, cuda + ) + elif args.method_name == 'dro': + from algorithms.dro import DRO + train_method= DRO( + args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, From 472a6e8387dc83dbcb01d6872807da3df7b1b10a Mon Sep 17 00:00:00 2001 From: divyat09 Date: Wed, 26 Aug 2020 14:05:24 +0000 Subject: [PATCH 2/3] CSD algorithm added --- algorithms/algo.py | 6 +- algorithms/csd.py | 161 ++++++++++++++++++++++++++++++++++++++++ evaluation/base_eval.py | 2 +- models/resnet.py | 15 ++-- train.py | 13 +++- 5 files changed, 187 insertions(+), 10 deletions(-) create mode 100644 algorithms/csd.py diff --git a/algorithms/algo.py b/algorithms/algo.py index 1acfa76..df966da 100644 --- a/algorithms/algo.py +++ b/algorithms/algo.py @@ -53,7 +53,11 @@ def get_model(self): phi= DomainBed( self.args.img_c ) if 'resnet' in self.args.model_name: from models.resnet import get_resnet - phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name, + if self.args.method_name in ['csd', 'matchdg_ctr']: + fc_layer=0 + else: + fc_layer= self.args.fc_layer + phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained) print('Model Architecture: ', self.args.model_name) diff --git a/algorithms/csd.py b/algorithms/csd.py new file mode 100644 index 0000000..84aca90 --- /dev/null +++ b/algorithms/csd.py @@ -0,0 +1,161 @@ +import sys +import numpy as np +import argparse +import copy +import random +import json + +import torch +from torch.autograd import grad +from torch import nn, optim +from torch.nn import functional as F +from torchvision import datasets, transforms +from torchvision.utils import save_image +from torch.autograd import Variable +import torch.utils.data as data_utils + +from .algo import BaseAlgo +from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity + + +class CSD(BaseAlgo): + def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda): + + super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda) + + # H_dim as per the feature layer dimension of ResNet-18 + ## TODO: Make it customizable with arg parser + H_dim= 512 + self.K, m, self.num_classes = 1, H_dim, self.args.out_classes + num_domains = self.total_domains + + self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) + self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True) + + self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True) + self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True) + + self.opt= optim.SGD([ + {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) }, + {'params': self.sms }, + {'params': self.sm_biases }, + {'params': self.embs }, + {'params': self.cs_wt } + ], lr= self.args.lr, weight_decay= 5e-4, momentum= 0.9, nesterov=True ) + + self.criterion = torch.nn.CrossEntropyLoss() + + def forward(self, x, y, di, eval_case=0): + x = self.phi(x) + w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :] + logits_common = torch.matmul(x, w_c) + b_c + + if eval_case: + return logits_common + + domains= di + c_wts = torch.matmul(domains, self.embs) + + # B x K + batch_size = x.shape[0] + c_wts = torch.cat((torch.ones((batch_size, 1), dtype=torch.float).to(self.cuda)*self.cs_wt, c_wts), 1) + c_wts = torch.tanh(c_wts).to(self.cuda) + w_d, b_d = torch.einsum("bk,krl->brl", c_wts, self.sms), torch.einsum("bk,kl->bl", c_wts, self.sm_biases) + logits_specialized = torch.einsum("brl,br->bl", w_d, x) + b_d + + specific_loss = self.criterion(logits_specialized, y) + class_loss = self.criterion(logits_common, y) + + sms = self.sms + diag_tensor = torch.stack([torch.eye(self.K+1).to(self.cuda) for _ in range(self.num_classes)], dim=0) + cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0) + orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2) + + loss = 0.5*class_loss + 0.5*specific_loss + orth_loss + return loss, logits_common + + def epoch_callback(self, nepoch, final=False): + if nepoch % 100 == 0: + print (self.embs, torch.norm(self.sms[0]), torch.norm(self.sms[1])) + + def train(self): + + for epoch in range(self.args.epochs): + + if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag): + data_match_tensor, label_match_tensor= self.get_match_function(epoch) + + penalty_csd=0 + train_acc= 0.0 + train_size=0 + + perm = torch.randperm(data_match_tensor.size(0)) + data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0) + label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0) + print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split)) + + #Batch iteration over single epoch + for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset): + # print('Batch Idx: ', batch_idx) + + self.opt.zero_grad() + loss_e= torch.tensor(0.0).to(self.cuda) + + x_e= x_e.to(self.cuda) + y_e= torch.argmax(y_e, dim=1).to(self.cuda) + + #Forward Pass + csd_loss, out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=0) + loss_e+= csd_loss + penalty_csd += float(loss_e) + + #Backprorp + loss_e.backward(retain_graph=False) + self.opt.step() + + del csd_loss + del loss_e + torch.cuda.empty_cache() + + train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item() + train_size+= y_e.shape[0] + + + print('Train Loss Basic : ', penalty_csd ) + print('Train Acc Env : ', 100*train_acc/train_size ) + print('Done Training for epoch: ', epoch) + + #Val Dataset Accuracy + self.val_acc.append( self.get_test_accuracy('val') ) + + #Test Dataset Accuracy + self.final_acc.append( self.get_test_accuracy('test') ) + + # Save the model's weights post training + self.save_model() + + + def get_test_accuracy(self, case): + + #Test Env Code + test_acc= 0.0 + test_size=0 + if case == 'val': + dataset= self.val_dataset + elif case == 'test': + dataset= self.test_dataset + + for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(dataset): + with torch.no_grad(): + x_e= x_e.to(self.cuda) + y_e= torch.argmax(y_e, dim=1).to(self.cuda) + + #Forward Pass + out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=1) + + test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item() + test_size+= y_e.shape[0] + + print(' Accuracy: ', case, 100*test_acc/test_size ) + + return 100*test_acc/test_size diff --git a/evaluation/base_eval.py b/evaluation/base_eval.py index f2abe48..c42633a 100644 --- a/evaluation/base_eval.py +++ b/evaluation/base_eval.py @@ -59,7 +59,7 @@ def __init__(self, args, train_dataset, test_dataset, train_domains, self.args.ctr_model_name ) - if self.args.method_name in ['erm_match', 'erm', 'irm', 'dro']: + if self.args.method_name in ['erm_match', 'erm', 'irm', 'dro', 'csd']: self.save_path= self.base_res_dir + '/Model_' + self.post_string elif self.args.method_name == 'matchdg_ctr': diff --git a/models/resnet.py b/models/resnet.py index 4f4397a..acc321e 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -13,7 +13,7 @@ def forward(self, x): return x -def get_resnet(model_name, classes, erm_base, num_ch, pre_trained): +def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained): if model_name == 'resnet18': model= torchvision.models.resnet18(pre_trained) n_inputs = model.fc.in_features @@ -23,13 +23,14 @@ def get_resnet(model_name, classes, erm_base, num_ch, pre_trained): n_inputs = model.fc.in_features n_outputs= classes - if erm_base == 'matchdg_ctr': - model.fc = Identity(n_inputs) -# model.fc= nn.Sequential( nn.Linear(n_inputs, n_inputs), -# nn.ReLU(), -# ) - else: + if fc_layer: model.fc = nn.Linear(n_inputs, n_outputs) + else: + print('Here') + model.fc = Identity(n_inputs) +# model.fc= nn.Sequential( nn.Linear(n_inputs, n_inputs), +# nn.ReLU(), +# ) if num_ch==1: model.conv1 = nn.Conv2d(1, 64, diff --git a/train.py b/train.py index 870eaad..0aa0f68 100644 --- a/train.py +++ b/train.py @@ -41,10 +41,12 @@ help='Height of the image in dataset') parser.add_argument('--img_w', type=int, default= 224, help='Width of the image in dataset') +parser.add_argument('--fc_layer', type=int, default= 1, + help='ResNet architecture customization; 0: No fc_layer with resnet; 1: fc_layer for classification with resnet') parser.add_argument('--match_layer', type=str, default='logit_match', help='rep_match: Matching at an intermediate representation level; logit_match: Matching at the logit level') parser.add_argument('--pos_metric', type=str, default='l2', - help='Cost to function to evaluate distance between two representations; Options: l1; l2; cos') + help='Cost function to evaluate distance between two representations; Options: l1; l2; cos') parser.add_argument('--rep_dim', type=int, default=250, help='Representation dimension for contrsative learning') parser.add_argument('--pre_trained',type=int, default=0, @@ -190,6 +192,15 @@ training_list_size, base_res_dir, run, cuda ) + elif args.method_name == 'csd': + from algorithms.csd import CSD + train_method= CSD( + args, train_dataset, val_dataset, + test_dataset, train_domains, + total_domains, domain_size, + training_list_size, base_res_dir, + run, cuda + ) #Train the method: It will save the model's weights post training and evalute it on test accuracy From 6b6ac24e28df3edb76ed39a8d9a163fc537658d2 Mon Sep 17 00:00:00 2001 From: divyat09 Date: Thu, 27 Aug 2020 08:59:22 +0000 Subject: [PATCH 3/3] CSD last layer specific parameters added to model save and load --- algorithms/csd.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/algorithms/csd.py b/algorithms/csd.py index 84aca90..637f45a 100644 --- a/algorithms/csd.py +++ b/algorithms/csd.py @@ -159,3 +159,10 @@ def get_test_accuracy(self, case): print(' Accuracy: ', case, 100*test_acc/test_size ) return 100*test_acc/test_size + + def save_model(self): + # Store the weights of the model + torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.post_string + '.pth') + # Store the parameters + torch.save(self.sms, self.base_res_dir + '/Sms_' + self.post_string + ".pt") + torch.save(self.sm_biases, self.base_res_dir + '/SmBiases_' + self.post_string + ".pt")