In [0]:
import torchvision.transforms as tvtf
import torchvision as tv
import torch
import torch.nn as nn

In [0]:
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [0]:
image_size= 28
batch_size = 64

tf_source = tvtf.Compose([
    tvtf.Resize(image_size),
    tvtf.ToTensor(),
    tvtf.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5))
])


tf_target = tvtf.Compose([
    tvtf.Resize(image_size),
    tvtf.ToTensor(),
    tvtf.Normalize(mean=(0.1307,), std=(0.1307,))
])

In [0]:
dt_target = tv.datasets.MNIST(root='C:/Users/mtech _1/Desktop/gsoc',train=True, transform=tf_target,download=True)

In [0]:
ds_source = tv.datasets.SVHN(root='C:/Users/mtech _1/Desktop/gsoc',split='train', transform=tf_source,download=True)

Using downloaded and verified file: C:/Users/mtech _1/Desktop/gsoc/train_32x32.mat


In [0]:
dl_source = torch.utils.data.DataLoader(ds_source, batch_size)
dl_target = torch.utils.data.DataLoader(dt_target, batch_size)

In [0]:
from torch.autograd import Function

class GradientReversalFn(Function):
  @staticmethod
  def forward(self, x, alpha):
    self.alpha=alpha
    
    return x.view_as(x)
  
  @staticmethod
  def backward(self, grad_output):
    output = grad_output.neg()*self.alpha
    
    return output, None

In [0]:
class DACNN(nn.Module):
  
  def __init__(self):
    super().__init__()
    self.feature_extractor=nn.Sequential(
        nn.Conv2d(3,64,kernel_size=5),
        nn.BatchNorm2d(64), nn.MaxPool2d(2),
        nn.ReLU(True),
        nn.Conv2d(64,64, kernel_size=5),
        nn.BatchNorm2d(64),nn.Dropout2d(),nn.MaxPool2d(2),
        nn.ReLU(True),
        nn.Conv2d(64,128,kernel_size=4),
        
    )
    
    self.class_classifier=nn.Sequential(
        nn.Linear(128*1*1,3072), nn.BatchNorm1d(3072),nn.Dropout2d(),
        nn.ReLU(True),
        nn.Linear(3072,2048), nn.BatchNorm1d(2048),
        nn.ReLU(True),
        nn.Linear(2048,10),
        nn.LogSoftmax(dim=1),
    )
    
    self.domain_classifier=nn.Sequential(
        nn.Linear(128*1*1,1024),nn.BatchNorm1d(1024),
        nn.ReLU(True),
        nn.Linear(1024,1024),
        nn.ReLU(True),
        nn.Linear(1024,10),
        nn.LogSoftmax(dim=1),
    )
    
    
    
  def forward(self, x, grl_lambda=1.0):
    
    x= x.expand(x.data.shape[0], 3, image_size, image_size)
    
    features = self.feature_extractor(x)
    features= features.view(-1,128*1*1)
    reverse_features= GradientReversalFn.apply(features, grl_lambda)
    
    class_pred = self.class_classifier(features)
    domain_pred = self.domain_classifier(reverse_features)
    
    return class_pred, domain_pred

In [0]:

model = DACNN()

x_s, y_s = next(iter(dl_source))
x_t, y_t = next(iter(dl_target))

print('source domain: ', x_s.shape, y_s.shape)
print('target domain: ', x_t.shape, y_t.shape)

model(x_s)
model(x_t)

source domain:  torch.Size([64, 3, 28, 28]) torch.Size([64])
target domain:  torch.Size([64, 1, 28, 28]) torch.Size([64])


(tensor([[-2.2455, -2.4973, -2.1269, -2.5838, -2.1413, -2.3132, -2.0836, -2.4933,
          -2.4097, -2.2660],
         [-2.3298, -2.4438, -2.1230, -2.1562, -2.5305, -2.3406, -2.1839, -2.2907,
          -2.4515, -2.2579],
         [-2.3060, -2.3156, -2.1187, -2.3726, -2.3722, -2.5651, -2.2848, -2.3332,
          -2.5099, -1.9806],
         [-2.3211, -2.4496, -2.2610, -2.3752, -2.5351, -2.3456, -2.1679, -2.0964,
          -2.5222, -2.0744],
         [-2.1515, -2.6835, -2.2961, -2.4822, -2.4471, -2.5603, -2.4422, -2.2533,
          -1.8362, -2.1550],
         [-2.1097, -2.2974, -2.5170, -2.2955, -2.4851, -2.1845, -2.3403, -2.1935,
          -2.2466, -2.4376],
         [-2.2209, -2.3716, -2.3230, -2.2865, -2.2252, -2.7136, -1.9814, -2.5168,
          -2.4242, -2.1449],
         [-2.3101, -2.7095, -2.5525, -2.6182, -2.3795, -2.1501, -2.2054, -1.5271,
          -2.6245, -2.5976],
         [-2.3390, -2.2948, -2.2797, -2.3904, -2.4056, -2.1884, -2.1112, -2.2347,
          -2.5538, -2.2959],
 

In [0]:
import torch.optim as optim
lr = 1e-3
n_epochs=400

model=DACNN()
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.08, momentum=0.9)

loss_fn_class = torch.nn.NLLLoss()
loss_fn_domain= torch.nn.NLLLoss()



In [0]:
batch_size= 2000
dl_source= torch.utils.data.DataLoader(ds_source, batch_size)
dl_target= torch.utils.data.DataLoader(dt_target, batch_size)

max_batches= min(len(dl_source), len(dl_target))

In [0]:
len(ds_source), len(dt_target)

(73257, 60000)

In [0]:
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [0]:
len(dl_source), len(dl_target)

(37, 30)

In [0]:
steps=0
import numpy as np
for epoch_idx in range(n_epochs):
  print(f'Epoch{epoch_idx+1:04d}/ {n_epochs:04d}', end= '\n=============\n')
  dl_source_iter= iter(dl_source)
  dl_target_iter= iter(dl_target)
  
  for batch_idx in range(max_batches):
    steps+=1
    optimizer.zero_grad()
    
    p= float(batch_idx + epoch_idx*max_batches)/(n_epochs*max_batches)
    grl_lambda= 2./ (1.+np.exp(-10*p))-1
    #grl_lambda= 1
    
    
    x_s,y_s = next(dl_source_iter)
    x_s, y_s = x_s.to(device), y_s.to(device)
    y_s_domain= torch.zeros(batch_size, dtype= torch.long)
    y_s_domain= y_s_domain.to(device)
    
    class_pred, domain_pred = model(x_s, grl_lambda)
    loss_s_label=loss_fn_class(class_pred, y_s)
    loss_s_domain = loss_fn_domain(domain_pred, y_s_domain)
    
    x_t, _ = next(dl_target_iter)
    y_t_domain = torch.ones(batch_size, dtype= torch.long)
    x_t, y_t_domain = x_t.to(device), y_t_domain.to(device)
    
    _, domain_pred= model(x_t, grl_lambda)
    
    loss_t_domain= loss_fn_domain(domain_pred, y_t_domain)
    
    
    loss= loss_t_domain +loss_s_domain + loss_s_label
    loss.backward()
    optimizer.step()
    
    
    #if (steps%10) == 0:
  model.eval()
  test_loss1=0
  accuracy1=0
        
  test_loss2=0
  accuracy2=0
        
  for t_images, t_labels in iter(dl_target):
    t_images, t_labels = t_images.to(device), t_labels.to(device)
    t_logps,_= model(t_images, grl_lambda)
    t_loss= loss_fn_class(t_logps, t_labels)
    test_loss1 += loss.item()
          
          
    t_ps = torch.exp(t_logps)
          
    t_top_ps, t_top_class = t_ps.topk(1 ,dim=1)
    equality= t_top_class==t_labels.view(*t_top_class.shape)
    accuracy1 += torch.mean(equality.type(torch.FloatTensor)).item()
          
  for s_images, s_labels in iter(dl_source):
    s_images, s_labels = s_images.to(device), s_labels.to(device)
    s_logps,_= model(s_images, grl_lambda)
    s_loss= loss_fn_class(s_logps, s_labels)
    test_loss2 += loss.item()
          
          
    s_ps = torch.exp(s_logps)
          
    s_top_ps, s_top_class = s_ps.topk(1 ,dim=1)
    equality= s_top_class==s_labels.view(*s_top_class.shape)
    accuracy2 += torch.mean(equality.type(torch.FloatTensor)).item()
          
        
    
    
  print(f'[{batch_idx+1}/{max_batches}]'
        f'class_loss: {loss_s_label.item(): .4f}       '     f's_domain_loss: {loss_s_domain.item():.4f}   '
        f't_domain_loss:{loss_t_domain.item():.4f}       '     f'grl_lambda: {grl_lambda:.3f}    '
        f'Target accuracy: {accuracy1/len(dl_target):.3f}   '    f'source accuracy: {accuracy2/len(dl_source):.3f}'
        )
      
  model.train()
      
      
      
    
    #if batch_idx==2:
      #print('This is just a demo, stopping.....')
      #break 

Epoch0001/ 0400
[30/30]class_loss:  0.8885       s_domain_loss: 0.7965   t_domain_loss:0.4775       grl_lambda: 0.012    Target accuracy: 0.534   source accuracy: 0.416
Epoch0002/ 0400
[30/30]class_loss:  0.5801       s_domain_loss: 0.1747   t_domain_loss:0.0661       grl_lambda: 0.025    Target accuracy: 0.526   source accuracy: 0.226
Epoch0003/ 0400
[30/30]class_loss:  0.5131       s_domain_loss: 0.3824   t_domain_loss:0.7424       grl_lambda: 0.037    Target accuracy: 0.552   source accuracy: 0.226
Epoch0004/ 0400
[30/30]class_loss:  0.4795       s_domain_loss: 0.1931   t_domain_loss:0.3311       grl_lambda: 0.050    Target accuracy: 0.535   source accuracy: 0.242
Epoch0005/ 0400
[30/30]class_loss:  0.4730       s_domain_loss: 0.6028   t_domain_loss:0.4042       grl_lambda: 0.062    Target accuracy: 0.507   source accuracy: 0.281
Epoch0006/ 0400
[30/30]class_loss:  0.4665       s_domain_loss: 0.2895   t_domain_loss:0.6016       grl_lambda: 0.074    Target accuracy: 0.535   source ac

KeyboardInterrupt: ignored