In [None]:
import numpy as np
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import DatasetFolder, VisionDataset
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random
import pandas as pd
import glob
import scipy.misc
import imageio
import argparse

In [None]:
# Download the homework archive from Google Drive (by file id) and save it as hw2_data.zip.
!gdown 1XUl4tqq3kKyWRe2lyHZ8s7U9yc_x3mWT -O hw2_data.zip

Downloading...
From: https://drive.google.com/uc?id=1XUl4tqq3kKyWRe2lyHZ8s7U9yc_x3mWT
To: /content/hw2_data.zip
100% 665M/665M [00:02<00:00, 296MB/s]


In [None]:
# -n: never overwrite files that already exist, so re-running this cell does not stall on
#     unzip's interactive "replace?" prompt; -q: keep the per-file listing out of the output.
!unzip -n -q ./hw2_data.zip -d ./
# The archive is no longer needed once extracted.
!rm hw2_data.zip

Archive:  ./hw2_data.zip
replace ./hw2_data/.DS_Store? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [None]:
# Install the pinned package versions this notebook was written against.
# NOTE(review): this cell sits after the import cell and after the gdown download cell; on a
# fresh environment it presumably has to run before them - confirm the intended cell order.
# NOTE(review): gdown is the only package here without a version pin.
!pip install imageio==2.26.0 \
             matplotlib==3.7.0 \
             numpy==1.23.1 \
             Pillow==9.4.0 \
             scipy==1.10.0 \
             pandas==1.5.3 \
             torch==2.0.1 \
             torchvision==0.15.2 \
             gdown



In [None]:
# Seed every RNG imported by this notebook (Python, NumPy, PyTorch) from one named constant.
SEED = 10901041
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed returns the Generator; the trailing ';' keeps its repr out of the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7973247a8690>

# Data

In [None]:
class DigitDataset(Dataset):
    """Digit images stored as PNG files, optionally labelled through a per-split CSV file.

    Two layouts are supported:

    * ``split='all'``: ``data_dir`` is the image folder itself. Every ``.png`` file in it is
      served with the placeholder label ``-1``.
    * any other ``split``: labels come from ``<data_dir>/<split>.csv`` (one header row, then
      ``filename,label`` rows) and the images from ``<data_dir>/data/``.

    Args:
        data_dir: Root folder of the dataset (or the image folder when ``split='all'``).
        split: ``'all'`` or the stem of the CSV file to read (e.g. ``'train'``, ``'val'``).
        transform: Callable applied to each PIL image; a falsy value returns the raw PIL image.

    Attributes:
        image_files: Image paths, sorted by file name.
        labels: Labels aligned index-by-index with ``image_files``.
        imgs: (CSV splits only) file names aligned with ``image_files``.
    """

    def __init__(self, data_dir, split='train', transform=transforms.ToTensor()):
        self.data_dir = data_dir
        self.transform = transform
        self.split = split
        if split == 'all':
            # Store full paths: bare file names only resolve when the working directory
            # happens to be data_dir.
            self.image_files = sorted(
                os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.png')
            )
            self.labels = [-1 for _ in self.image_files]
        else:
            label_file = os.path.join(data_dir, f'{split}.csv')
            image_dir = os.path.join(data_dir, "data")
            with open(label_file, 'r') as f:
                # Skip the header row and any blank lines.
                rows = [line.strip().split(',') for line in f.readlines()[1:] if line.strip()]
            label_by_name = {row[0]: int(row[1]) for row in rows}

            # Keep only files that exist on disk AND are listed in the CSV, sorted by name.
            # The dict lookup replaces a linear scan of a list for every directory entry.
            self.imgs = sorted(x for x in os.listdir(image_dir) if x in label_by_name)
            self.image_files = [os.path.join(image_dir, x) for x in self.imgs]
            # Look labels up by file name so they follow the sorted image order; keeping them
            # in CSV row order would pair images with the wrong labels whenever the CSV is
            # not already sorted by file name.
            self.labels = [label_by_name[x] for x in self.imgs]

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``; the image is loaded as RGB."""
        img_path = self.image_files[idx]
        image = Image.open(img_path).convert('RGB')  # in case there are grayscale images

        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

In [None]:
# all_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/usps/data", split="all")
# all_dataloader = DataLoader(all_dataset, batch_size=32, shuffle=False)

# means = []
# stds = []
# for images, _ in all_dataloader:
#     means.append(torch.mean(images, dim=[0,2,3]))
#     stds.append(torch.std(images, dim=[0,2,3]))

# mean = torch.stack(means).mean(dim=0)
# std = torch.stack(stds).mean(dim=0)


In [None]:
# print(mean)
# print(std)
# print(len(all_dataset))

# mnistm mean/std
# tensor([0.4632, 0.4669, 0.4195])
# tensor([0.2520, 0.2365, 0.2605])

# svhn mean/std:
# tensor([0.4414, 0.4459, 0.4716])
# tensor([0.1960, 0.2000, 0.1976])

# usps mean/std
# tensor([0.2574, 0.2574, 0.2574])
# tensor([0.3524, 0.3524, 0.3524])

In [None]:
# usps_transform = transforms.Compose([
#         transforms.Resize(28),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.2574, 0.2574, 0.2574], std=[0.3524, 0.3524, 0.3524])
#     ])
# usps_train_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/usps", split="train", transform=usps_transform)
# usps_train_dataloader = torch.utils.data.DataLoader(usps_train_dataset, batch_size=32, shuffle=True)
# usps_val_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/usps", split="val", transform=usps_transform)
# usps_val_dataloader = torch.utils.data.DataLoader(usps_val_dataset, batch_size=32, shuffle=True)

In [None]:
# svhn_transform = transforms.Compose([
#         transforms.Resize(28),
# #         transforms.Lambda(rgb_to_grayscale),
#         transforms.Grayscale(num_output_channels=1),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.4414, 0.4459, 0.4716], std=[0.1960, 0.2000, 0.1976])
#     ])
# svhn_train_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/svhn", split="train", transform=svhn_transform)
# svhn_train_dataloader = torch.utils.data.DataLoader(svhn_train_dataset, batch_size=32, shuffle=True)
# svhn_val_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/svhn", split="val", transform=svhn_transform)
# svhn_val_dataloader = torch.utils.data.DataLoader(svhn_val_dataset, batch_size=32, shuffle=True)

In [None]:
# mnistm_transform = transforms.Compose([
#         transforms.Resize(28),
#         transforms.Grayscale(num_output_channels=1),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.4632, 0.4669, 0.4195], std=[0.2520, 0.2365, 0.2605])
#     ])
# mnistm_train_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/mnistm", split="train", transform=mnistm_transform)
# mnistm_train_dataloader = torch.utils.data.DataLoader(mnistm_train_dataset, batch_size=32, shuffle=True)
# mnistm_val_dataset = DigitDataset(data_dir="/kaggle/input/dlcv-hw2/hw2_data/digits/mnistm", split="val", transform=mnistm_transform)
# mnistm_val_dataloader = torch.utils.data.DataLoader(mnistm_val_dataset, batch_size=32, shuffle=True)

In [None]:
# print(len(usps_train_dataset))
# print(len(usps_val_dataset))

# Model

In [None]:
# Model Reference:
# https://github.com/NaJaeMin92/pytorch_DANN
# https://github.com/fungtion/DANN/tree/master

from torch.autograd import Function

class ReverseLayerF(Function):
    """Gradient reversal layer.

    Forward is the identity; backward multiplies the incoming gradient by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged (as a view) and remember ``alpha`` for the backward pass."""
        ctx.alpha = alpha
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated, ``alpha``-scaled gradient for ``x`` and ``None`` for ``alpha``."""
        reversed_grad = grad_output.neg() * ctx.alpha
        # The second forward input (alpha) is not differentiated.
        return reversed_grad, None

class DANN(nn.Module):
    """Domain-adversarial network: a shared conv feature extractor with two heads.

    * ``feature``: two conv/batch-norm/pool/ReLU stages producing a ``50 x 4 x 4`` map
      from a ``3 x 28 x 28`` input.
    * ``class_classifier``: 10-way digit classifier on the flattened features.
    * ``domain_classifier``: 2-way domain classifier fed through a gradient reversal layer.

    Both heads end in ``LogSoftmax`` over the class dimension, so they emit log-probabilities.
    Sub-module names are kept stable so existing checkpoints still load.
    """

    def __init__(self):
        super(DANN, self).__init__()
        self.feature = nn.Sequential()
        self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))
        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
        self.feature.add_module('f_pool1', nn.MaxPool2d(2))
        self.feature.add_module('f_relu1', nn.ReLU(True))
        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
        self.feature.add_module('f_drop1', nn.Dropout())
        self.feature.add_module('f_pool2', nn.MaxPool2d(2))
        self.feature.add_module('f_relu2', nn.ReLU(True))

        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
        self.class_classifier.add_module('c_drop1', nn.Dropout())
        self.class_classifier.add_module('c_fc2', nn.Linear(100, 100))
        self.class_classifier.add_module('c_bn2', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu2', nn.ReLU(True))
        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
        # dim must be explicit: the implicit-dim form is deprecated and warns on every call.
        # dim=1 is the class dimension of the (batch, 10) input, same as the domain head.
        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))

        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))
        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))

    def forward(self, input_data, alpha):
        """Classify digits and domains for a batch of 28x28 images.

        Args:
            input_data: Batch of images shaped ``(B, 1, 28, 28)`` or ``(B, 3, 28, 28)``;
                single-channel input is broadcast to three channels.
            alpha: Scale of the reversed gradient flowing from the domain head into the
                feature extractor.

        Returns:
            ``(class_output, domain_output)`` log-probabilities shaped ``(B, 10)`` and ``(B, 2)``.
        """
        # The explicit 28x28 target also rejects inputs of any other spatial size up front.
        input_data = input_data.expand(input_data.shape[0], 3, 28, 28)
        feature = self.feature(input_data)
        feature = feature.view(-1, 50 * 4 * 4)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output

# Train

In [None]:
# Experiment configuration for the training cell.
data_path = ""              # not referenced by the visible cells
save_path = "./"            # directory the checkpoints are written to
train_name = "svhn"         # labelled (source) domain the model is trained on
val_name = "svhn"           # domain whose train split is used as target when isDANN is True
isDANN = False              # True: also train the domain classifier on target images
mode = "svhn_upper_bound"   # prefix of the checkpoint file names

def optimizer_scheduler(optimizer, p, init_lr=0.01, gamma=10, power=0.75):
    """Anneal the learning rate of every parameter group in place.

    Sets ``lr = init_lr / (1 + gamma * p) ** power`` on each group; with the default
    constants this is ``0.01 / (1 + 10 * p) ** 0.75``.
    Reference: https://github.com/NaJaeMin92/pytorch_DANN/blob/master/utils.py

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        p: Training progress; the training loop passes a value that grows from 0 towards 1.
        init_lr: Learning rate at ``p == 0``.
        gamma: Multiplier applied to ``p`` inside the decay term.
        power: Exponent of the decay term.

    Returns:
        The same ``optimizer`` object, to allow ``optimizer = optimizer_scheduler(...)``.
    """
    # The rate is identical for every group, so compute it once.
    lr = init_lr / (1. + gamma * p) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer

#     Reference: https://github.com/fungtion/DANN/
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# Per-channel normalisation statistics (mean, std) for each supported digit domain.
DOMAIN_STATS = {
    "mnistm": ([0.4632, 0.4669, 0.4195], [0.2520, 0.2365, 0.2605]),
    "svhn": ([0.4414, 0.4459, 0.4716], [0.1960, 0.2000, 0.1976]),
    "usps": ([0.2574, 0.2574, 0.2574], [0.3524, 0.3524, 0.3524]),
}


def _domain_transform_and_dir(name, role):
    """Return ``(transform, data_dir)`` for digit domain ``name``.

    ``role`` only labels the error message. An unknown name raises ``ValueError`` right away
    instead of printing a message and failing later on an undefined variable.
    """
    if name not in DOMAIN_STATS:
        raise ValueError(f"Wrong {role}: {name!r} (expected one of {sorted(DOMAIN_STATS)})")
    mean, std = DOMAIN_STATS[name]
    transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transform, f"./hw2_data/digits/{name}"


# Labelled source-domain training data.
train_transform, train_data_dir = _domain_transform_and_dir(train_name, "train_name")
train_dataset = DigitDataset(data_dir=train_data_dir, split="train", transform=train_transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

if isDANN:
    # Target-domain training images; their class labels are never used, only the domain label.
    train_img_transform, train_img_data_dir = _domain_transform_and_dir(val_name, "DANN val_name")
    train_img_dataset = DigitDataset(data_dir=train_img_data_dir, split="train", transform=train_img_transform)
    train_img_dataloader = torch.utils.data.DataLoader(train_img_dataset, batch_size=32, shuffle=True)

model = DANN().to(device)
# The lr given here is overwritten at every step by optimizer_scheduler().
optimizer = torch.optim.RAdam(model.parameters(), lr=3e-4, weight_decay=1e-2)
class_loss = nn.CrossEntropyLoss()
domain_loss = nn.CrossEntropyLoss()

n_epoch = 200
for epoch in range(n_epoch):
    print(f"epoch: {epoch}")
    model.train()
    train_dataloader_iter = iter(train_dataloader)
    # No validation iterator here: the validation loader is only built in the validation
    # section, so referencing it made this cell fail on a fresh kernel (and it was unused).
    if isDANN:
        train_img_dataloader_iter = iter(train_img_dataloader)
    for i in tqdm(range(len(train_dataloader)), position=0, leave=True):
        # p is the overall training progress; it drives both the learning-rate decay and
        # alpha, the weight of the reversed domain gradient.
        p = float(i + epoch * len(train_dataloader)) / n_epoch / len(train_dataloader)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1
        optimizer = optimizer_scheduler(optimizer=optimizer, p=p)
        optimizer.zero_grad()

        train_image, train_label = next(train_dataloader_iter)
        train_image, train_label = train_image.to(device), train_label.to(device)

        # Source batch: digit classification loss plus domain loss with domain label 0.
        domain_label = torch.zeros(train_image.size(0)).long().to(device)
        class_output, domain_output = model(input_data=train_image, alpha=alpha)
        err_s_label = class_loss(class_output, train_label)
        err_s_domain = domain_loss(domain_output, domain_label)

        if isDANN:
            # The target loader is restarted whenever it runs out before the source loader.
            try:
                train_target_image = next(train_img_dataloader_iter)
            except StopIteration:
                train_img_dataloader_iter = iter(train_img_dataloader)
                train_target_image = next(train_img_dataloader_iter)
            (target_image,_,) = train_target_image
            target_image = target_image.to(device)
            # Target batch: domain loss only, with domain label 1.
            domain_label = torch.ones(len(target_image))
            domain_label = domain_label.long().to(device)
            _, domain_output = model(input_data=target_image, alpha=alpha)
            err_t_domain = domain_loss(domain_output, domain_label)
            loss = err_t_domain + err_s_domain + err_s_label
        else:
            # NOTE(review): without target batches the domain head only ever sees label 0,
            # yet its reversed gradient still reaches the feature extractor - confirm that
            # keeping err_s_domain in the source-only baseline is intended.
            loss = err_s_domain + err_s_label
        loss.backward()
        optimizer.step()
    # Checkpoint every 50 epochs and after the final epoch.
    if epoch % 50 == 0 or epoch == n_epoch-1 :
        torch.save(model, "{0}/{1}_model_{2}.ckpt".format(save_path, mode, epoch))

device: cuda
epoch: 0


  0%|          | 0/1986 [00:00<?, ?it/s]

  input = module(input)


epoch: 1


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 2


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 3


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 4


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 5


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 6


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 7


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 8


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 9


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 10


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 11


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 12


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 13


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 14


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 15


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 16


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 17


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 18


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 19


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 20


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 21


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 22


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 23


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 24


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 25


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 26


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 27


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 28


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 29


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 30


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 31


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 32


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 33


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 34


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 35


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 36


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 37


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 38


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 39


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 40


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 41


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 42


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 43


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 44


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 45


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 46


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 47


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 48


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 49


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 50


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 51


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 52


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 53


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 54


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 55


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 56


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 57


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 58


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 59


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 60


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 61


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 62


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 63


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 64


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 65


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 66


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 67


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 68


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 69


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 70


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 71


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 72


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 73


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 74


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 75


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 76


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 77


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 78


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 79


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 80


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 81


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 82


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 83


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 84


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 85


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 86


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 87


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 88


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 89


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 90


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 91


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 92


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 93


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 94


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 95


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 96


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 97


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 98


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 99


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 100


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 101


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 102


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 103


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 104


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 105


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 106


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 107


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 108


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 109


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 110


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 111


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 112


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 113


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 114


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 115


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 116


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 117


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 118


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 119


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 120


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 121


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 122


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 123


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 124


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 125


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 126


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 127


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 128


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 129


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 130


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 131


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 132


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 133


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 134


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 135


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 136


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 137


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 138


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 139


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 140


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 141


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 142


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 143


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 144


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 145


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 146


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 147


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 148


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 149


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 150


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 151


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 152


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 153


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 154


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 155


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 156


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 157


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 158


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 159


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 160


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 161


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 162


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 163


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 164


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 165


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 166


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 167


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 168


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 169


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 170


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 171


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 172


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 173


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 174


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 175


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 176


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 177


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 178


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 179


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 180


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 181


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 182


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 183


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 184


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 185


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 186


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 187


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 188


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 189


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 190


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 191


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 192


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 193


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 194


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 195


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 196


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 197


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 198


  0%|          | 0/1986 [00:00<?, ?it/s]

epoch: 199


  0%|          | 0/1986 [00:00<?, ?it/s]

# Validation

In [None]:
# Evaluate a saved checkpoint on the validation split of `val_name`.
# NOTE(review): the checkpoint name below does not match the `mode` set in the training
# cell above ("svhn_upper_bound"); presumably it comes from an earlier run - confirm.
model_path = "./usps_upper_bound_model_199.ckpt"
val_name = "usps"
if val_name == "svhn":
    val_transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4414, 0.4459, 0.4716], std=[0.1960, 0.2000, 0.1976])
    ])
    val_data_dir="./hw2_data/digits/svhn"
elif val_name == "usps":
    val_transform = transforms.Compose([
        transforms.Resize(28),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.2574, 0.2574, 0.2574], std=[0.3524, 0.3524, 0.3524])
    ])
    val_data_dir="./hw2_data/digits/usps"
else:
    # Fail here rather than a few lines later on an undefined val_data_dir.
    raise ValueError(f"Wrong val_name: {val_name!r}")
val_dataset = DigitDataset(data_dir=val_data_dir, split="val", transform=val_transform)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)

# torch.load unpickles a whole model object - only load checkpoints you produced yourself.
# Map onto `device` (the device the inputs are moved to) instead of a hard-coded "cuda:0".
model = torch.load(model_path, map_location=device)
model = model.eval()
alpha = 0
class_loss = nn.CrossEntropyLoss()

n_total = 0
n_correct = 0
valid_loss = []
# Gradients are not needed for evaluation.
with torch.no_grad():
    for valid_image, valid_label in tqdm(val_dataloader):
        valid_image = valid_image.to(device)
        valid_label = valid_label.to(device)
        class_output, _ = model(input_data=valid_image, alpha=alpha)

        loss = class_loss(class_output, valid_label)
        pred = class_output.argmax(dim=1)
        n_correct += (pred == valid_label).sum().item()
        # Count the real batch size: the last batch can hold fewer than 32 samples, and
        # adding a fixed 32 made the reported accuracy too low.
        n_total += valid_label.size(0)
        valid_loss.append(loss.item())

acc = n_correct / n_total

print(f"{val_name} dataset acc: {acc:.5f}")
print(f"sum(valid_loss) / len(valid_loss): {sum(valid_loss) / len(valid_loss)}")

  0%|          | 0/47 [00:00<?, ?it/s]



usps dataset acc: 0.96543
sum(valid_loss) / len(valid_loss): 0.10044003832847515


In [None]:
# Outputs captured by forward hooks, keyed by the name passed to get_activation().
activation = {}


def get_activation(name):
    """Build a forward hook that stores the hooked module's detached output in ``activation[name]``."""

    def hook(module, inputs, output):
        # Detach so the stored tensor is cut off from the autograd graph.
        captured = output.detach()
        activation.update({name: captured})

    return hook

In [None]:
# tsne
model_path = "./dann_usps_model_199.ckpt"
from sklearn.manifold import TSNE
# Load first, THEN switch to eval mode. Calling eval() before the load only affected the
# previous model object; checkpoints written by the training loop are pickled while the
# model is in train mode, so Dropout/BatchNorm would otherwise stay in train mode here.
model = torch.load(model_path, map_location=device)
model.eval()
# Capture the output of the last layer of the feature extractor on every forward pass.
model.feature[-1].register_forward_hook(get_activation("layer_last"))

features = []
labels = []
domains = []

source_val_transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4632, 0.4669, 0.4195], std=[0.2520, 0.2365, 0.2605])
])
source_val_dataset = DigitDataset(data_dir="./hw2_data/digits/mnistm", split="val", transform=source_val_transform)
source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=32, shuffle=True)


def _collect_domain_features(loader, domain_id):
    """Append flattened features, labels and ``domain_id`` for every sample in ``loader``."""
    for x, y in loader:
        x = x.to(device)
        with torch.no_grad():
            model.feature(x)
        # The forward hook has just stored this batch's feature map.
        feature = activation["layer_last"].view(-1, 50*4*4).cpu().numpy()
        features.extend(feature)
        labels.extend(y.cpu().numpy())
        domains.extend(np.full(y.size(0), float(domain_id)))


# Source domain (domain id 0)
_collect_domain_features(source_val_loader, 0)

# Target domain (domain id 1)
# NOTE(review): val_dataloader comes from the validation cell, so the target domain shown
# here is whatever val_name was set to there - confirm it matches this checkpoint.
_collect_domain_features(val_dataloader, 1)

features = np.array(features)
labels = np.array(labels)
domains = np.array(domains)

def visualize_with_tsne(features, labels, domains, title="t-SNE visualization"):
    """Embed ``features`` in 2-D with t-SNE and show two scatter plots side by side.

    The left panel colours the points by digit class (0-9), the right panel by domain
    (``domains == 0`` is labelled "Source", ``domains == 1`` "Target").

    Args:
        features: 2-D array of feature vectors, one row per sample.
        labels: Array of class ids aligned with ``features``.
        domains: Array of domain ids aligned with ``features``.
        title: Prefix for both panel titles.
    """
    embedding = TSNE(n_components=2, random_state=0, verbose=1).fit_transform(features)

    fig, (class_ax, domain_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # By Class
    for digit in range(10):
        in_class = labels == digit
        class_ax.scatter(embedding[in_class, 0], embedding[in_class, 1], s=10, label=str(digit))
    class_ax.set_title(title + ' (By Class)')
    class_ax.legend()

    # By Domain
    for domain_id, domain_name in ((0, "Source"), (1, "Target")):
        in_domain = domains == domain_id
        domain_ax.scatter(embedding[in_domain, 0], embedding[in_domain, 1], s=10, label=domain_name)
    domain_ax.set_title(title + ' (By Domain)')
    domain_ax.legend()

    plt.show()

visualize_with_tsne(features, labels, domains, "DANN t-SNE Visualization")
