In [1]:
# from google.colab import drive
# drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os

# os.chdir("/content/drive/MyDrive/Colab Notebooks/NASA_Transfer_Learning")

import sys
sys.path.append("../")

from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import data_loader
from models.ganin import GaninModel

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
# random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2 ** 32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2 ** 32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2 ** 32 - 1)

# Pytorch performance tuninng guide - NVIDIA
torch.backends.cudnn.benchmark = True  # speeds up convolution operations

# Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

Device:  cpu


In [3]:
# Hyperparameters
IMG_SIZE = 28
BATCH_SIZE = 64
EPOCHS = 15
LR = 2e-4

In [4]:
# MNIST
transform_m = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset_m = datasets.MNIST(
    "data/mnist", train=True, download=False, transform=transform_m
)
trainloader_m = torch.utils.data.DataLoader(
    trainset_m, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True
)

testset_m = datasets.MNIST(
    "data/mnist", train=False, download=False, transform=transform_m
)
testloader_m = torch.utils.data.DataLoader(
    testset_m, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True
)

# MNIST-M
transform_mm = transforms.Compose(
    [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

DATA_DIR = "data/mnist_m/processed/"

trainloader_mm = data_loader.fetch(
    data_dir=os.path.join(DATA_DIR, "mnist_m_train.pt"),
    batch_size=BATCH_SIZE,
    transform=transform_mm,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

testloader_mm = data_loader.fetch(
    data_dir=os.path.join(DATA_DIR, "mnist_m_test.pt"),
    batch_size=BATCH_SIZE,
    transform=transform_mm,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

net = GaninModel().to(device)

criterion_l = nn.CrossEntropyLoss()
criterion_d = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)

num_batches = min(len(trainloader_m), len(trainloader_mm))  # ~60000/batch_size
print("No. of Batches: ", num_batches)

Device:  cuda:0
No. of Batches:  1875


In [5]:
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print("Memory Allocated: (GB)")
    print("Allocated: ", round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1))
    print("Cached: ", round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1))

Tesla V100-SXM2-16GB
Memory Allocated: (GB)
Allocated:  0.0
Cached:  0.0


In [6]:
test_accuracy = []
start_time = datetime.now()

for epoch in range(EPOCHS):

    running_loss_total = 0
    running_loss_l = 0
    running_loss_d = 0

    dataiter_mm = iter(trainloader_mm)
    dataiter_m = iter(trainloader_m)
    alpha = (2 / (1 + np.exp(-10 * ((epoch + 0.0) / EPOCHS)))) - 1
    print(f"alpha: {alpha}")

    net.train()
    for batch in range(1, num_batches + 1):
        loss_total = 0
        loss_d = 0
        loss_l = 0

        optimizer.zero_grad()
        # for source domain
        imgs, lbls = dataiter_m.next()
        imgs, lbls = imgs.to(device), lbls.to(device)
        imgs = torch.cat((imgs, imgs, imgs), 1)

        # with torch.cuda.amp.autocast():
        out_l, out_d = net(imgs, alpha)
        loss_l_src = criterion_l(out_l, lbls)
        actual_d = torch.zeros(out_d.shape).to(device)
        loss_d_src = criterion_d(out_d, actual_d)

        # for target domain
        imgs, lbls = dataiter_mm.next()
        imgs = imgs.to(device)

        # with torch.cuda.amp.autocast():
        _, out_d = net(imgs, alpha)
        actual_d = torch.ones(out_d.shape).to(device)
        loss_d_tgt = criterion_d(out_d, actual_d)

        loss_total = loss_d_src + loss_l_src + loss_d_tgt
        loss_total.backward()
        optimizer.step()

        running_loss_total += loss_total
        running_loss_d += loss_d_src + loss_d_tgt
        running_loss_l += loss_l_src

        if batch % 300 == 0:
            print(f"Epoch: {epoch}/{EPOCHS} Batch: {batch}/{num_batches}")
            print(f"Total Loss: {running_loss_total/batch}")
        #   print(f"Label Loss: {running_loss_l/batch}")
        #   print(f"Domain Loss: {running_loss_d/batch}")

    net.eval()
    test_loss = 0
    accuracy = 0

    with torch.no_grad():
        net.eval()
        for imgs, lbls in testloader_mm:
            imgs, lbls = imgs.to(device), lbls.to(device)
            # print(imgs.shape)
            # print(lbls.shape)

            logits, _ = net(imgs, alpha=0)
            # print(logits.shape)
            test_loss += criterion_l(logits, lbls)

            # derive which class index corresponds to max value
            preds_l = torch.max(logits, dim=1)[
                1
            ]  # [1]: indices(class) corresponding to max values
            equals = torch.eq(preds_l, lbls)  # count no. of correct class predictions
            accuracy += torch.mean(equals.float())

    test_accuracy.append(accuracy / len(testloader_mm))
    print(f"Test accuracy: {accuracy / len(testloader_mm)}")
    print("\n")

end_time = datetime.now()
duration = end_time - start_time
print(f"Training Time for {EPOCHS} epochs: {duration}")

alpha: 0.0
Epoch: 0/15 Batch: 300/1875
Total Loss: 2.290130376815796
Label Loss: 0.917344868183136
Domain Loss: 1.3727856874465942
Epoch: 0/15 Batch: 600/1875
Total Loss: 1.9377704858779907
Label Loss: 0.5933164358139038
Domain Loss: 1.3444538116455078
Epoch: 0/15 Batch: 900/1875
Total Loss: 1.7587906122207642
Label Loss: 0.45895248651504517
Domain Loss: 1.2998383045196533
Epoch: 0/15 Batch: 1200/1875
Total Loss: 1.6209685802459717
Label Loss: 0.38094764947891235
Domain Loss: 1.2400199174880981
Epoch: 0/15 Batch: 1500/1875
Total Loss: 1.50449538230896
Label Loss: 0.33010706305503845
Domain Loss: 1.174387812614441
Epoch: 0/15 Batch: 1800/1875
Total Loss: 1.4023813009262085
Label Loss: 0.2956162989139557
Domain Loss: 1.1067640781402588
Test accuracy: 0.47783544659614563


alpha: 0.32151273753163445
Epoch: 1/15 Batch: 300/1875
Total Loss: 1.341217279434204
Label Loss: 0.13582591712474823
Domain Loss: 1.2053916454315186
Epoch: 1/15 Batch: 600/1875
Total Loss: 1.3551908731460571
Label Loss: