In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import numpy as np
from PIL import Image
from torchsummary import summary
import pandas as pd

In [2]:
class DATA(Dataset):
    """Folder of ``*.png`` images, optionally labelled from a sibling CSV.

    For ``img_path = '<root>/<split>'`` the labels are read from
    ``'<root>/<split>.csv'`` (indexed by its ``image_name`` column) and looked
    up in the ``label`` column by the image file's basename.

    Args:
        img_path: directory containing the ``.png`` images.
        label: if True, ``__getitem__`` returns ``(img, label)``; otherwise it
            returns ``(img, -1)`` and the CSV is not read at all.
        transform: callable applied to each RGB PIL image, or ``None`` to skip.
            The default converts to a tensor and maps each channel from
            [0, 1] to [-1, 1].
    """

    def __init__(self, img_path, label=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])):
        self.img_path = img_path
        self.transform = transform
        self.label = label

        # normpath strips a trailing separator, which would otherwise make the
        # basename (the split name) empty.
        norm_path = os.path.normpath(img_path)

        # Only labelled datasets need the CSV, so unlabelled folders need not ship one.
        if label:
            split = os.path.basename(norm_path)
            csv_path = os.path.join(os.path.dirname(norm_path), '{}.csv'.format(split))
            self.labels = pd.read_csv(csv_path, index_col='image_name')
        else:
            self.labels = None

        # glob order is filesystem-dependent; sort so index -> file is reproducible.
        self.filepaths = sorted(glob.glob(os.path.join(img_path, '*.png')))
        self.len = len(self.filepaths)

    def __getitem__(self, index):
        """Return ``(image, label)``; label is -1 when the dataset is unlabelled."""
        fn = self.filepaths[index]
        # Context manager closes the file handle; convert('RGB') makes
        # single-channel images 3-channel as well.
        with Image.open(fn) as im:
            img = im.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        if self.label:
            return img, self.labels.loc[os.path.basename(fn), 'label']
        return img, -1

    def __len__(self):
        """Number of ``.png`` files found in ``img_path``."""
        return self.len

In [3]:
# Load datasets. Each item is (img, label); the *_without_label variants
# return (img, -1) and serve as unlabelled target-domain data.
mnistm_train = DATA(img_path = 'hw2_data/digits/mnistm/train')
mnistm_train_without_label = DATA(img_path = 'hw2_data/digits/mnistm/train', label=False)
mnistm_test = DATA(img_path = 'hw2_data/digits/mnistm/test')

svhn_train = DATA(img_path = 'hw2_data/digits/svhn/train')
svhn_train_without_label = DATA(img_path = 'hw2_data/digits/svhn/train', label=False)
svhn_test = DATA(img_path = 'hw2_data/digits/svhn/test')

usps_train = DATA(img_path = 'hw2_data/digits/usps/train')
# Unlabelled USPS images: the target domain for MNIST-M -> USPS (previously
# pointed at the SVHN folder by mistake).
usps_train_without_label = DATA(img_path = 'hw2_data/digits/usps/train', label=False)
usps_test = DATA(img_path = 'hw2_data/digits/usps/test')


In [4]:
# Sanity check: Python type of the label DATA returns for a labelled sample.
sample_img, sample_label = mnistm_train[0]
type(sample_label)

numpy.int64

In [5]:
# --- optimisation ---
batch_size, num_epochs = 128, 100
lr, beta1 = 2e-4, 0.5        # Adam learning rate and beta1

# --- network shape ---
# nc: input channels, ndf: base conv width, ncls: number of classes
nc, ndf, ncls = 3, 28, 10

In [6]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# reference from https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
from torch.autograd import Function


class GradReverse(Function):
    """Gradient reversal: identity forward, gradient scaled by ``-alpha`` backward."""

    @staticmethod
    def forward(ctx, x, alpha=1.0):
        ctx.alpha = alpha
        # Return a view rather than ``x`` itself so the output is a distinct
        # tensor node in the autograd graph.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; alpha is a plain number, so it gets None.
        return grad_output.neg() * ctx.alpha, None


def grad_reverse(x, alpha=1.0):
    """Pass ``x`` through unchanged, but multiply its gradient by ``-alpha``.

    Args:
        x: input tensor.
        alpha: reversal strength; the default 1.0 gives a plain sign flip.
            Larger/smaller values let callers schedule the adversarial weight.

    Returns:
        A tensor with the same values as ``x``.
    """
    return GradReverse.apply(x, alpha)

In [15]:
class DANN(nn.Module):
    """Domain-Adversarial Neural Network for 28x28 images.

    A shared convolutional feature extractor (``main``) feeds two heads:

    * ``realfake``: domain classifier with a sigmoid output. Its input passes
      through ``grad_reverse``, so the feature extractor receives the negated
      domain-loss gradient.
    * ``cls``: digit classifier returning raw logits over ``ncls`` classes.

    Channel counts come from the module-level config ``nc``, ``ndf`` and
    ``ncls``, read when the model is constructed.
    """

    def __init__(self):
        super(DANN, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 28 x 28
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 14 x 14
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 7 x 7
            nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 2 x 2
        )
        self.realfake = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, 2, 1, 0, bias=False),
            nn.Sigmoid()
        )
        # No Softmax here: the training/eval code feeds these scores to
        # nn.CrossEntropyLoss, which applies log-softmax itself. A Softmax
        # layer would normalize twice and weaken the gradients. argmax-based
        # predictions are unaffected. (Kept as a Sequential so the parameter
        # name stays "cls.0.weight".)
        self.cls = nn.Sequential(
            nn.Conv2d(ndf * 8, ncls, 2, 1, 0, bias=False),
        )

    def forward(self, input):
        """Return ``(dt, cls)`` for a batch of images.

        For an ``(N, nc, 28, 28)`` input, ``dt`` is ``(N, 1, 1, 1)`` domain
        probabilities and ``cls`` is ``(N, ncls, 1, 1)`` class logits.
        """
        features = self.main(input)
        dt = self.realfake(grad_reverse(features))
        cls = self.cls(features)
        return dt, cls

In [16]:
def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    Modules whose class name contains ``Conv`` get weights ~ N(0, 0.02).
    Otherwise, names containing ``BatchNorm`` get weights ~ N(1, 0.02) and
    zero bias. Every other module is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [17]:
# Build the network on the selected device; Module.apply returns the module
# itself, so the custom initialisation can be chained.
netDANN = DANN().to(device).apply(weights_init)
print(netDANN)

DANN(
  (main): Sequential(
    (0): Conv2d(3, 28, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(28, 56, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(56, 112, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(112, 224, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (realfake): Sequential(
    (0): Conv2d(224, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)
    (1): Sigmoid()
  )
  (cls): Sequent

In [18]:
# Layer-by-layer output shapes for one 3x28x28 input. torchsummary defaults to
# device="cuda"; passing the actual device type lets this run on CPU-only machines.
# (`summary` is already imported in the first cell.)
summary(netDANN, (3, 28, 28), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 28, 14, 14]           1,344
         LeakyReLU-2           [-1, 28, 14, 14]               0
            Conv2d-3             [-1, 56, 7, 7]          25,088
       BatchNorm2d-4             [-1, 56, 7, 7]             112
         LeakyReLU-5             [-1, 56, 7, 7]               0
            Conv2d-6            [-1, 112, 4, 4]          56,448
       BatchNorm2d-7            [-1, 112, 4, 4]             224
         LeakyReLU-8            [-1, 112, 4, 4]               0
            Conv2d-9            [-1, 224, 2, 2]         401,408
      BatchNorm2d-10            [-1, 224, 2, 2]             448
        LeakyReLU-11            [-1, 224, 2, 2]               0
           Conv2d-12              [-1, 1, 1, 1]             896
          Sigmoid-13              [-1, 1, 1, 1]               0
           Conv2d-14             [-1, 1

  input = module(input)


In [19]:
# Adaptation settings, keyed by mode (source -> target, accuracy noted):
#   0: MNIST-M -> USPS     72%
#   1: SVHN    -> MNIST-M  42%
#   2: USPS    -> SVHN     28%

_source_sets = {0: mnistm_train, 1: svhn_train, 2: usps_train}
_target_sets = {0: usps_train_without_label,
                1: mnistm_train_without_label,
                2: svhn_train_without_label}
_test_sets = {0: usps_test, 1: mnistm_test, 2: svhn_test}


def _loaders(datasets, shuffle):
    """One DataLoader per mode over the given {mode: dataset} mapping."""
    return {k: DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
            for k, ds in datasets.items()}


trainLoader = _loaders(_source_sets, shuffle=True)         # labelled source (train_bound)
trainLoader_domain = _loaders(_source_sets, shuffle=True)  # labelled source (train)
trainLoader_target = _loaders(_target_sets, shuffle=True)  # unlabelled target (train)
testLoader = _loaders(_test_sets, shuffle=False)           # target test split

# test() saves a checkpoint only when accuracy (%) exceeds these.
# NOTE(review): mode 2 uses 12 although 28% is listed above; confirm intent.
test_threshold = {0: 72, 1: 42, 2: 12}
# (source, target) names per mode; used in checkpoint filenames.
mode_word = {0: ('mnistm', 'usps'), 1: ('svhn', 'mnistm'), 2: ('usps', 'svhn')}

In [20]:
# Placeholder only: the train/test functions below take mode as an argument
# and never read this global.
mode = -1

In [21]:
def train_bound(model, epoch, mode, testmode):
    """Train only the class head with plain supervised cross-entropy.

    No domain loss is used; presumably this gives the source-only / target-only
    reference accuracy (the "bound"). After every epoch the model is scored
    with ``test_bound`` on ``testLoader[testmode]``.

    Args:
        model: network whose forward returns (domain_output, class_scores);
            trained in place.
        epoch: number of passes over ``trainLoader[mode]``.
        mode: key into ``trainLoader`` selecting the labelled training set.
        testmode: key into ``testLoader`` used for evaluation.

    Returns:
        List of per-iteration training losses as Python floats.
    """
    criterion_cls = nn.CrossEntropyLoss()
    # Optimise the model that was passed in, not the global netDANN.
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))

    loss = []
    for ep in range(epoch):
        # test_bound() switches to eval mode, so re-enable training each epoch.
        model.train()
        for i, (data, label) in enumerate(trainLoader[mode]):
            optimizer.zero_grad()
            data, label = data.to(device), label.to(device)
            output = model(data)[1].view(-1, ncls)

            err = criterion_cls(output, label)
            err.backward()
            optimizer.step()
            # Store a float, not the tensor, so the autograd graph is released.
            loss.append(err.item())

            if i % 50 == 0:
                print("{}/{} \tloss:{} ".format(
                    i, len(trainLoader[mode]), err.item()
                ))

        test_bound(model, ep, testmode)
    return loss


    

In [22]:
def test_bound(model, ep, mode):
    """Evaluate the class head on ``testLoader[mode]`` and print the result.

    Args:
        model: network whose forward returns (domain_output, class_scores).
        ep: epoch index; unused here, kept so the signature matches ``test``.
        mode: key into ``testLoader``.

    Returns:
        Test accuracy in percent.
    """
    criterion_cls = nn.CrossEntropyLoss()
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in testLoader[mode]:
            data, target = data.to(device), target.to(device)
            output = model(data)[1].view(-1, ncls)
            val_loss += criterion_cls(output, target).item()  # batch-mean loss
            pred = output.argmax(dim=1)  # predicted class per sample
            correct += pred.eq(target).sum().item()

    val_loss /= len(testLoader[mode])  # mean of the per-batch mean losses
    len_dataset = len(testLoader[mode].dataset)
    acc = 100. * correct / len_dataset
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len_dataset, acc
        ))
    return acc

In [23]:
def train(model, epoch, mode):
    """Train ``model`` with the DANN objective for the pair selected by ``mode``.

    Each iteration pairs one labelled source batch (``trainLoader_domain[mode]``)
    with one unlabelled target batch (``trainLoader_target[mode]``); iteration
    stops when the shorter loader is exhausted. The loss is:

    * source class cross-entropy,
    * domain BCE on source (label 1), and
    * domain BCE on target (label 0).

    After every epoch the model is scored (and possibly checkpointed) by ``test``.

    Args:
        model: network whose forward returns (domain_prob, class_scores);
            trained in place.
        epoch: number of passes over the zipped loaders.
        mode: key into the loader dicts.

    Returns:
        Dict of per-iteration float losses: ``'dt'`` (source + target domain
        BCE), ``'cls'`` (source class CE) and ``'total'``.
    """
    criterion_dt = nn.BCELoss()
    criterion_cls = nn.CrossEntropyLoss()

    source_label = 1.
    target_label = 0.

    # Optimise the model that was passed in, not the global netDANN.
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))

    history = {'dt': [], 'cls': [], 'total': []}
    for ep in range(epoch):
        # test() switches to eval mode, so re-enable training each epoch.
        model.train()
        for i, (source, target) in enumerate(zip(trainLoader_domain[mode], trainLoader_target[mode])):
            optimizer.zero_grad()

            # Labelled source batch: a single forward pass yields both heads
            # (a second pass would double compute and BatchNorm stat updates).
            data, label = source
            data, label = data.to(device), label.to(device)
            source_dt = torch.full((data.size(0),), source_label, dtype=torch.float, device=device)

            output_dt, output_cls = model(data)
            err_source_dt = criterion_dt(output_dt.view(-1), source_dt)
            err_source_cls = criterion_cls(output_cls.view(-1, ncls), label)

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_dt: %.4f\tLoss_cls: %.4f\tLoss: %.4f'
                  % (ep, epoch, i, len(trainLoader_domain[mode]),
                     err_source_dt.item(), err_source_cls.item(),
                     (err_source_cls + err_source_dt).item()))

            # Unlabelled target batch: domain loss only (its labels are -1).
            target_data = target[0].to(device)
            target_dt = torch.full((target_data.size(0),), target_label, dtype=torch.float, device=device)
            err_target_dt = criterion_dt(model(target_data)[0].view(-1), target_dt)

            err = err_source_cls + err_source_dt + err_target_dt
            err.backward()
            optimizer.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_dt: %.4f\tLoss: %.4f'
                  % (ep, epoch, i, len(trainLoader_target[mode]),
                     err_target_dt.item(), err_target_dt.item()))

            # Floats, not tensors, so each iteration's graph can be freed.
            history['dt'].append(err_source_dt.item() + err_target_dt.item())
            history['cls'].append(err_source_cls.item())
            history['total'].append(err.item())

        test(model, ep, mode)
    return history




In [24]:
def test(model, ep, mode):
    """Evaluate the class head on ``testLoader[mode]`` and checkpoint if good enough.

    Prints the mean per-batch cross-entropy and the accuracy. When the accuracy
    (percent) exceeds ``test_threshold[mode]``, the whole model is pickled to
    ``p3/<source>_<target>_<ep>_<acc>.pth``.

    Args:
        model: network whose forward returns (domain_output, class_scores).
        ep: epoch index, used only in the checkpoint filename.
        mode: key into ``testLoader``, ``test_threshold`` and ``mode_word``.

    Returns:
        Test accuracy in percent.
    """
    criterion_cls = nn.CrossEntropyLoss()
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in testLoader[mode]:
            data, target = data.to(device), target.to(device)
            output = model(data)[1].view(-1, ncls)
            val_loss += criterion_cls(output, target).item()  # batch-mean loss
            pred = output.argmax(dim=1)  # predicted class per sample
            correct += pred.eq(target).sum().item()

    val_loss /= len(testLoader[mode])  # mean of the per-batch mean losses
    len_dataset = len(testLoader[mode].dataset)
    acc = 100. * correct / len_dataset
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len_dataset, acc
        ))

    if acc > test_threshold[mode]:
        # torch.save does not create missing parent directories.
        os.makedirs('p3', exist_ok=True)
        torch.save(model, 'p3/{}_{}_{}_{:.0f}.pth'.format(
            mode_word[mode][0], mode_word[mode][1], ep, acc))
    return acc

In [25]:
# Full DANN training for every source -> target pair. Off by default so that
# "Run All" does not launch three long training runs.
RUN_DANN_TRAINING = False
DANN_EPOCHS = 30

if RUN_DANN_TRAINING:
    for train_mode in (0, 1, 2):
        # Fresh model per pair: weights_init alone re-draws weights but leaves
        # BatchNorm running statistics from the previous pair in place.
        netDANN = DANN().to(device)
        netDANN.apply(weights_init)
        train(netDANN, DANN_EPOCHS, train_mode)


In [26]:
# Load a whole-model checkpoint written by torch.save(model, ...).
# torch.load unpickles arbitrary objects, so only load trusted files.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE: on PyTorch >= 2.6 torch.load defaults to weights_only=True; loading a
# pickled nn.Module then additionally needs weights_only=False.
model = torch.load('usps.pth', map_location=device)
model.eval()

# forward() requires an input batch, so score the checkpoint on a test loader.
# Assumes usps.pth is the MNIST-M -> USPS model (mode 0, USPS test set); TODO confirm.
test_bound(model, 0, 0)
