Skip to content

Commit

Permalink
Merge pull request #10 from microsoft/algos
Browse files Browse the repository at this point in the history
New Algos added:

DRO
CSD
  • Loading branch information
divyat09 committed Sep 8, 2020
2 parents 38011b9 + 16483d4 commit 59852a2
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 26 deletions.
17 changes: 13 additions & 4 deletions algorithms/algo.py
Expand Up @@ -17,9 +17,10 @@
from utils.match_function import get_matched_pairs

class BaseAlgo():
def __init__(self, args, train_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, run, cuda):
def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, run, cuda):
self.args= args
self.train_dataset= train_dataset
self.val_dataset= val_dataset
self.test_dataset= test_dataset
self.train_domains= train_domains
self.total_domains= total_domains
Expand Down Expand Up @@ -52,7 +53,11 @@ def get_model(self):
phi= DomainBed( self.args.img_c )
if 'resnet' in self.args.model_name:
from models.resnet import get_resnet
phi= get_resnet(self.args.model_name, self.args.out_classes, self.args.method_name,
if self.args.method_name in ['csd', 'matchdg_ctr']:
fc_layer=0
else:
fc_layer= self.args.fc_layer
phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer,
self.args.img_c, self.args.pre_trained)

print('Model Architecture: ', self.args.model_name)
Expand Down Expand Up @@ -90,11 +95,15 @@ def get_match_function(self, epoch):

return data_match_tensor, label_match_tensor

def get_test_accuracy(self):
def get_test_accuracy(self, case):

#Test Env Code
test_acc= 0.0
test_size=0
if case == 'val':
dataset= self.val_dataset
elif case == 'test':
dataset= self.test_dataset

for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.test_dataset):
with torch.no_grad():
Expand All @@ -108,6 +117,6 @@ def get_test_accuracy(self):
test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
test_size+= y_e.shape[0]

print(' Accuracy: ', 100*test_acc/test_size )
print(' Accuracy: ', case, 100*test_acc/test_size )

return 100*test_acc/test_size
168 changes: 168 additions & 0 deletions algorithms/csd.py
@@ -0,0 +1,168 @@
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity


class CSD(BaseAlgo):
def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda):

super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda)

# H_dim as per the feature layer dimension of ResNet-18
## TODO: Make it customizable with arg parser
H_dim= 512
self.K, m, self.num_classes = 1, H_dim, self.args.out_classes
num_domains = self.total_domains

self.sms = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, m, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)
self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-3, size=[self.K+1, self.num_classes], dtype=torch.float, device='cuda:0'), requires_grad=True)

self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device='cuda:0'), requires_grad=True)
self.cs_wt = torch.nn.Parameter(torch.normal(mean=0, std=1e-3, size=[], dtype=torch.float, device='cuda:0'), requires_grad=True)

self.opt= optim.SGD([
{'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
{'params': self.sms },
{'params': self.sm_biases },
{'params': self.embs },
{'params': self.cs_wt }
], lr= self.args.lr, weight_decay= 5e-4, momentum= 0.9, nesterov=True )

self.criterion = torch.nn.CrossEntropyLoss()

def forward(self, x, y, di, eval_case=0):
x = self.phi(x)
w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
logits_common = torch.matmul(x, w_c) + b_c

if eval_case:
return logits_common

domains= di
c_wts = torch.matmul(domains, self.embs)

# B x K
batch_size = x.shape[0]
c_wts = torch.cat((torch.ones((batch_size, 1), dtype=torch.float).to(self.cuda)*self.cs_wt, c_wts), 1)
c_wts = torch.tanh(c_wts).to(self.cuda)
w_d, b_d = torch.einsum("bk,krl->brl", c_wts, self.sms), torch.einsum("bk,kl->bl", c_wts, self.sm_biases)
logits_specialized = torch.einsum("brl,br->bl", w_d, x) + b_d

specific_loss = self.criterion(logits_specialized, y)
class_loss = self.criterion(logits_common, y)

sms = self.sms
diag_tensor = torch.stack([torch.eye(self.K+1).to(self.cuda) for _ in range(self.num_classes)], dim=0)
cps = torch.stack([torch.matmul(sms[:, :, _], torch.transpose(sms[:, :, _], 0, 1)) for _ in range(self.num_classes)], dim=0)
orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)

loss = 0.5*class_loss + 0.5*specific_loss + orth_loss
return loss, logits_common

def epoch_callback(self, nepoch, final=False):
if nepoch % 100 == 0:
print (self.embs, torch.norm(self.sms[0]), torch.norm(self.sms[1]))

def train(self):

for epoch in range(self.args.epochs):

if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
data_match_tensor, label_match_tensor= self.get_match_function(epoch)

penalty_csd=0
train_acc= 0.0
train_size=0

perm = torch.randperm(data_match_tensor.size(0))
data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0)
label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0)
print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))

#Batch iteration over single epoch
for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset):
# print('Batch Idx: ', batch_idx)

self.opt.zero_grad()
loss_e= torch.tensor(0.0).to(self.cuda)

x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)

#Forward Pass
csd_loss, out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=0)
loss_e+= csd_loss
penalty_csd += float(loss_e)

#Backprorp
loss_e.backward(retain_graph=False)
self.opt.step()

del csd_loss
del loss_e
torch.cuda.empty_cache()

train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
train_size+= y_e.shape[0]


print('Train Loss Basic : ', penalty_csd )
print('Train Acc Env : ', 100*train_acc/train_size )
print('Done Training for epoch: ', epoch)

#Val Dataset Accuracy
self.val_acc.append( self.get_test_accuracy('val') )

#Test Dataset Accuracy
self.final_acc.append( self.get_test_accuracy('test') )

# Save the model's weights post training
self.save_model()


def get_test_accuracy(self, case):

#Test Env Code
test_acc= 0.0
test_size=0
if case == 'val':
dataset= self.val_dataset
elif case == 'test':
dataset= self.test_dataset

for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(dataset):
with torch.no_grad():
x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)

#Forward Pass
out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=1)

test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
test_size+= y_e.shape[0]

print(' Accuracy: ', case, 100*test_acc/test_size )

return 100*test_acc/test_size

def save_model(self):
# Store the weights of the model
torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.post_string + '.pth')
# Store the parameters
torch.save(self.sms, self.base_res_dir + '/Sms_' + self.post_string + ".pt")
torch.save(self.sm_biases, self.base_res_dir + '/SmBiases_' + self.post_string + ".pt")
100 changes: 100 additions & 0 deletions algorithms/dro.py
@@ -0,0 +1,100 @@
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity

class DRO(BaseAlgo):
def __init__(self, args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda):

super().__init__(args, train_dataset, val_dataset, test_dataset, train_domains, total_domains, domain_size, training_list_size, base_res_dir, post_string, cuda)

def train(self):

for epoch in range(self.args.epochs):

if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
data_match_tensor, label_match_tensor= self.get_match_function(epoch)

penalty_erm=0
train_acc= 0.0
train_size=0

perm = torch.randperm(data_match_tensor.size(0))
data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0)
label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0)
print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))

#Batch iteration over single epoch
for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset):
# print('Batch Idx: ', batch_idx)

self.opt.zero_grad()
loss_e= torch.tensor(0.0).to(self.cuda)

x_e= x_e.to(self.cuda)
y_e= torch.argmax(y_e, dim=1).to(self.cuda)
d_e= torch.argmax(d_e, dim=1).numpy()

#Forward Pass
out= self.phi(x_e)

erm_loss= torch.tensor(0.0).to(self.cuda)
if epoch > self.args.penalty_s:
# To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
total_batch_size= len(data_match_tensor_split)
if batch_idx >= total_batch_size:
break
curr_batch_size= data_match_tensor_split[batch_idx].shape[0]

data_match= data_match_tensor_split[batch_idx].to(self.cuda)
label_match= label_match_tensor_split[batch_idx].to(self.cuda)

for domain_idx in range(data_match.shape[1]):

data_idx= data_match[:,domain_idx,:,:,:]
feat_idx= self.phi( data_idx )

label_idx= label_match[:, domain_idx]
label_idx= label_idx.view(label_idx.shape[0])
erm_loss = torch.max(erm_loss, F.cross_entropy(feat_idx, label_idx.long()).to(self.cuda))

penalty_erm+= float(erm_loss)
loss_e += erm_loss

loss_e.backward(retain_graph=False)
self.opt.step()

del erm_loss
del loss_e
torch.cuda.empty_cache()

train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
train_size+= y_e.shape[0]


print('Train Loss Basic : ', penalty_erm )
print('Train Acc Env : ', 100*train_acc/train_size )
print('Done Training for epoch: ', epoch)

#Val Dataset Accuracy
self.val_acc.append( self.get_test_accuracy('val') )

#Test Dataset Accuracy
self.final_acc.append( self.get_test_accuracy('test') )

# Save the model's weights post training
self.save_model()

0 comments on commit 59852a2

Please sign in to comment.