Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from microsoft/algos
New Algos added: DRO CSD
- Loading branch information
Showing
10 changed files
with
423 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import sys | ||
import numpy as np | ||
import argparse | ||
import copy | ||
import random | ||
import json | ||
|
||
import torch | ||
from torch.autograd import grad | ||
from torch import nn, optim | ||
from torch.nn import functional as F | ||
from torchvision import datasets, transforms | ||
from torchvision.utils import save_image | ||
from torch.autograd import Variable | ||
import torch.utils.data as data_utils | ||
|
||
from .algo import BaseAlgo | ||
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity | ||
|
||
|
||
class CSD(BaseAlgo):
    """CSD training algorithm built on top of ``BaseAlgo``.

    The classifier head on top of the feature extractor ``self.phi`` is a
    stack of ``K + 1`` softmax layers (``self.sms`` / ``self.sm_biases``).
    Index 0 is the "common" classifier used for prediction; during training
    a per-sample "specialized" classifier is additionally formed as a
    tanh-squashed weighted combination of all ``K + 1`` layers, with weights
    coming from ``self.cs_wt`` and the per-domain embeddings ``self.embs``.

    NOTE(review): "CSD" presumably stands for Common-Specific Decomposition;
    confirm against the project documentation.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda):
        """Create the CSD parameters and the SGD optimizer.

        All arguments are forwarded unchanged to ``BaseAlgo.__init__``.
        """
        super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda)

        # H_dim as per the feature layer dimension of ResNet-18
        ## TODO: Make it customizable with arg parser
        H_dim = 512
        self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
        num_domains = self.total_domains

        # Create the parameters on the same device every other tensor in this
        # class is moved to (``self.cuda``) rather than a hard-coded 'cuda:0',
        # so training also works on CPU or on a non-default GPU.
        device = self.cuda

        # Stacked softmax weights / biases: index 0 is the common classifier.
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K + 1, m, self.num_classes], dtype=torch.float, device=device), requires_grad=True)
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K + 1, self.num_classes], dtype=torch.float, device=device), requires_grad=True)

        # One K-dimensional embedding per domain, plus a scalar weight for the
        # common component.
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device=device), requires_grad=True)
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device=device), requires_grad=True)

        self.opt = optim.SGD([
            {'params': filter(lambda p: p.requires_grad, self.phi.parameters())},
            {'params': self.sms},
            {'params': self.sm_biases},
            {'params': self.embs},
            {'params': self.cs_wt},
        ], lr=self.args.lr, weight_decay=5e-4, momentum=0.9, nesterov=True)

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x, y, di, eval_case=0):
        """Run the feature extractor and the CSD classifier head.

        Args:
            x: Input batch fed to ``self.phi``. The features are multiplied
                with an ``(m, num_classes)`` matrix, so ``self.phi(x)`` is
                assumed to be ``(batch, 512)`` -- TODO confirm.
            y: Class-index targets passed to ``CrossEntropyLoss``.
            di: Domain indicator that is matrix-multiplied with ``self.embs``;
                assumed to be a float ``(batch, num_domains)`` tensor
                (presumably one-hot) -- TODO confirm against the data loader.
            eval_case: When truthy, only the common logits are computed.

        Returns:
            ``logits_common`` when ``eval_case`` is truthy, otherwise the
            tuple ``(loss, logits_common)``.
        """
        x = self.phi(x)
        w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
        logits_common = torch.matmul(x, w_c) + b_c

        if eval_case:
            return logits_common

        # Per-sample domain-specific combination weights.
        c_wts = torch.matmul(di, self.embs)

        # Prepend the shared scalar weight -> B x (K + 1), then squash.
        batch_size = x.shape[0]
        common_wt = torch.ones((batch_size, 1), dtype=torch.float).to(self.cuda) * self.cs_wt
        c_wts = torch.cat((common_wt, c_wts), 1)
        c_wts = torch.tanh(c_wts).to(self.cuda)

        # Per-sample classifier as a weighted sum of the K + 1 softmax layers.
        w_d = torch.einsum("bk,krl->brl", c_wts, self.sms)
        b_d = torch.einsum("bk,kl->bl", c_wts, self.sm_biases)
        logits_specialized = torch.einsum("brl,br->bl", w_d, x) + b_d

        specific_loss = self.criterion(logits_specialized, y)
        class_loss = self.criterion(logits_common, y)

        # Orthogonality penalty: for each class, penalize the off-diagonal
        # entries of the Gram matrix of the K + 1 weight vectors.
        sms = self.sms
        # The identity mask is the same for every class, so build it once and
        # let broadcasting expand it over the class dimension.
        identity = torch.eye(self.K + 1).to(self.cuda).unsqueeze(0)
        cps = torch.stack(
            [
                torch.matmul(sms[:, :, cls_idx], torch.transpose(sms[:, :, cls_idx], 0, 1))
                for cls_idx in range(self.num_classes)
            ],
            dim=0,
        )
        orth_loss = torch.mean((1 - identity) * (cps - identity) ** 2)

        loss = 0.5 * class_loss + 0.5 * specific_loss + orth_loss
        return loss, logits_common

    def epoch_callback(self, nepoch, final=False):
        """Print the domain embeddings and classifier norms every 100 epochs."""
        if nepoch % 100 == 0:
            print (self.embs, torch.norm(self.sms[0]), torch.norm(self.sms[1]))

    def train(self):
        """Train for ``self.args.epochs`` epochs, then save the model.

        After every epoch the validation and test accuracies are appended to
        ``self.val_acc`` and ``self.final_acc`` respectively.
        """
        for epoch in range(self.args.epochs):

            if epoch == 0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
                data_match_tensor, label_match_tensor = self.get_match_function(epoch)

            penalty_csd = 0
            train_acc = 0.0
            train_size = 0

            # NOTE(review): the matched-data splits are only used for the log
            # line below; they are kept so the log output and the RNG stream
            # consumed by ``torch.randperm`` stay unchanged.
            perm = torch.randperm(data_match_tensor.size(0))
            data_match_tensor_split = torch.split(data_match_tensor[perm], self.args.batch_size, dim=0)
            label_match_tensor_split = torch.split(label_match_tensor[perm], self.args.batch_size, dim=0)
            print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))

            # Batch iteration over single epoch
            for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(self.train_dataset):

                self.opt.zero_grad()

                x_e = x_e.to(self.cuda)
                # Labels arrive as per-class score/one-hot rows; convert to indices.
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                # Forward pass
                loss_e, out = self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=0)
                penalty_csd += float(loss_e)

                # Backprop
                loss_e.backward(retain_graph=False)
                self.opt.step()

                del loss_e
                torch.cuda.empty_cache()

                train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                train_size += y_e.shape[0]

            print('Train Loss Basic : ', penalty_csd )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)

            # Val Dataset Accuracy
            self.val_acc.append(self.get_test_accuracy('val'))

            # Test Dataset Accuracy
            self.final_acc.append(self.get_test_accuracy('test'))

        # Save the model's weights post training
        self.save_model()

    def get_test_accuracy(self, case):
        """Return the accuracy (in percent) of the common classifier.

        Args:
            case: ``'val'`` to evaluate on ``self.val_dataset`` or ``'test'``
                to evaluate on ``self.test_dataset``.

        Raises:
            ValueError: If ``case`` is neither ``'val'`` nor ``'test'``.
        """
        if case == 'val':
            dataset = self.val_dataset
        elif case == 'test':
            dataset = self.test_dataset
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError(f"Unknown case {case!r}; expected 'val' or 'test'")

        test_acc = 0.0
        test_size = 0

        with torch.no_grad():
            for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(dataset):
                x_e = x_e.to(self.cuda)
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                # Forward pass (common classifier only)
                out = self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=1)

                test_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                test_size += y_e.shape[0]

        print(' Accuracy: ', case, 100*test_acc/test_size )

        return 100*test_acc/test_size

    def save_model(self):
        """Save the feature extractor weights and the softmax parameters."""
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.post_string + '.pth')
        # Store the parameters
        torch.save(self.sms, self.base_res_dir + '/Sms_' + self.post_string + ".pt")
        torch.save(self.sm_biases, self.base_res_dir + '/SmBiases_' + self.post_string + ".pt")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import sys | ||
import numpy as np | ||
import argparse | ||
import copy | ||
import random | ||
import json | ||
|
||
import torch | ||
from torch.autograd import grad | ||
from torch import nn, optim | ||
from torch.nn import functional as F | ||
from torchvision import datasets, transforms | ||
from torchvision.utils import save_image | ||
from torch.autograd import Variable | ||
import torch.utils.data as data_utils | ||
|
||
from .algo import BaseAlgo | ||
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity | ||
|
||
class DRO(BaseAlgo):
    """Worst-domain training algorithm built on top of ``BaseAlgo``.

    For every batch of matched data the cross-entropy loss is computed
    separately for each domain, and the maximum of those per-domain losses
    is minimized.

    NOTE(review): "DRO" presumably stands for distributionally robust
    optimization; confirm against the project documentation.
    """

    def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda):
        """Forward all arguments unchanged to ``BaseAlgo.__init__``."""
        super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda)

    def train(self):
        """Train for ``self.args.epochs`` epochs, then save the model.

        After every epoch the validation and test accuracies are appended to
        ``self.val_acc`` and ``self.final_acc`` respectively.
        """
        for epoch in range(self.args.epochs):

            if epoch == 0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
                data_match_tensor, label_match_tensor = self.get_match_function(epoch)

            penalty_erm = 0
            train_acc = 0.0
            train_size = 0

            perm = torch.randperm(data_match_tensor.size(0))
            data_match_tensor_split = torch.split(data_match_tensor[perm], self.args.batch_size, dim=0)
            label_match_tensor_split = torch.split(label_match_tensor[perm], self.args.batch_size, dim=0)
            print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))

            # Batch iteration over single epoch
            for batch_idx, (x_e, y_e, d_e, idx_e) in enumerate(self.train_dataset):

                self.opt.zero_grad()

                x_e = x_e.to(self.cuda)
                # Labels arrive as per-class score/one-hot rows; convert to indices.
                y_e = torch.argmax(y_e, dim=1).to(self.cuda)

                # Forward pass used only for the training-accuracy statistic;
                # it never contributes to the loss, so no graph is needed.
                with torch.no_grad():
                    out = self.phi(x_e)

                erm_loss = torch.tensor(0.0).to(self.cuda)
                if epoch > self.args.penalty_s:
                    # The matched data may yield fewer batches than the train
                    # loader; stop the epoch once they are exhausted.
                    if batch_idx >= len(data_match_tensor_split):
                        break

                    data_match = data_match_tensor_split[batch_idx].to(self.cuda)
                    label_match = label_match_tensor_split[batch_idx].to(self.cuda)

                    # Worst-case (maximum) cross-entropy over the domains,
                    # which are indexed along dim 1 of the matched tensors.
                    for domain_idx in range(data_match.shape[1]):
                        data_idx = data_match[:, domain_idx, :, :, :]
                        feat_idx = self.phi(data_idx)

                        label_idx = label_match[:, domain_idx]
                        label_idx = label_idx.view(label_idx.shape[0])
                        domain_loss = F.cross_entropy(feat_idx, label_idx.long()).to(self.cuda)
                        erm_loss = torch.max(erm_loss, domain_loss)

                penalty_erm += float(erm_loss)
                loss_e = erm_loss

                # While ``epoch <= penalty_s`` the loss is a constant zero
                # with no graph attached; calling backward() on it would
                # raise a RuntimeError, so skip the update in that case.
                if loss_e.requires_grad:
                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                del erm_loss
                del loss_e
                torch.cuda.empty_cache()

                train_acc += torch.sum(torch.argmax(out, dim=1) == y_e).item()
                train_size += y_e.shape[0]

            print('Train Loss Basic : ', penalty_erm )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)

            # Val Dataset Accuracy
            self.val_acc.append(self.get_test_accuracy('val'))

            # Test Dataset Accuracy
            self.final_acc.append(self.get_test_accuracy('test'))

        # Save the model's weights post training
        self.save_model()
Oops, something went wrong.