In [1]:
# Mount Google Drive at /content/drive so the notebook can reach files stored on Drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os

# Work from the toy/ directory on the mounted Drive; the relative "../data/..."
# paths used further down are resolved against this working directory.
toy_dir = '/content/drive/MyDrive/Colab Notebooks/NASA_Transfer_Learning/toy'
os.chdir(toy_dir)
print("Current working directory: ", os.getcwd())

Current working directory:  /content/drive/MyDrive/Colab Notebooks/NASA_Transfer_Learning/toy


In [4]:
import sys

# Add the project root directory to sys.path (the module search path) so that
# project-local packages can be imported — presumably the datasets / models /
# trainers / utils modules imported in the next cell live here.
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry.
project_root = "/content/drive/MyDrive/Colab Notebooks/NASA_Transfer_Learning/"
if project_root not in sys.path:
    sys.path.append(project_root)

In [6]:
from datetime import datetime
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from datasets import mnist, mnist_m
from models.ganin import SimpleClassifier
import trainers.ganin as trainer
# from trainer import train_ganin, test_ganin # getting called within /toy
from utils import transform, helper


# Fix the random seed through the project helper (intended to make runs deterministic)
helper.set_random_seed(seed=123)

# Device: use the first CUDA GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device: ", device)

Device:  cpu


In [5]:
# Install wandb ()

!pip install wandb --upgrade

Requirement already up-to-date: wandb in /usr/local/lib/python3.7/dist-packages (0.10.24)


In [6]:
import wandb
# Authenticate this session with Weights & Biases (interactive API-key prompt).
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [7]:
# Hyperparameters and descriptive metadata for the run
config = {
    "epochs": 15,
    "batch_size": 64,
    "learning_rate": 2e-4,
    "classes": 10,
    "img_size": 28,
    "dataset": "mnist-mnist_m",
    "architecture": "ganin",
}

In [8]:
def main(hyperparameters):
    """Train ``SimpleClassifier`` on MNIST and evaluate it on MNIST-M every epoch.

    The run is tracked with Weights & Biases (project ``uda-ganin-toy``). After
    each epoch one record is logged containing the value returned by
    ``trainer.train_simple_classifier`` (as ``train_loss``) and the value
    returned by ``trainer.test_simple_classifier`` on the MNIST-M test loader
    (as ``val_accuracy``).

    Relies on the notebook-level ``device`` and on the current working
    directory, because the dataset files are opened through relative
    ``../data/...`` paths.

    Args:
        hyperparameters: Mapping with at least the keys ``"epochs"``,
            ``"batch_size"`` and ``"learning_rate"``. It is stored as the W&B
            run config and is also passed unchanged to
            ``trainer.train_simple_classifier``.

    Returns:
        The model instance after the last epoch.
    """
    with wandb.init(project="uda-ganin-toy", config=hyperparameters, mode='online'):
        model = SimpleClassifier().to(device)
        wandb.watch(model, log='all', log_freq=100)

        # transforms
        transform_m = transform.get_transform(dataset="mnist")
        transform_mm = transform.get_transform(dataset="mnist_m")

        # dataloaders
        loaders_args = dict(
            batch_size=hyperparameters["batch_size"],
            shuffle=True,
            num_workers=1,
            # Pinned host memory only helps host-to-GPU copies; requesting it
            # on a CPU-only runtime just produces a warning.
            pin_memory=(device.type == "cuda"),
        )

        train_src = torch.load("../data/mnist/processed/train.pt")
        trainloader_src = mnist.fetch(data=train_src,
                                      transform=transform_m,
                                      **loaders_args)

        # testloader_src is fetched for symmetry only; it is not used below
        test_src = torch.load("../data/mnist/processed/test.pt")
        testloader_src = mnist.fetch(data=test_src,
                                     transform=transform_m,
                                     **loaders_args)

        train_tgt = torch.load("../data/mnist_m/processed/train.pt")
        trainloader_tgt = mnist_m.fetch(data=train_tgt,
                                        transform=transform_mm,
                                        **loaders_args)

        test_tgt = torch.load("../data/mnist_m/processed/test.pt")
        testloader_tgt = mnist_m.fetch(data=test_tgt,
                                       transform=transform_mm,
                                       **loaders_args)

        # CrossEntropyLoss takes integer class indices in [0, C-1] as targets
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=hyperparameters["learning_rate"])

        start_time = datetime.now()
        for epoch in range(1, hyperparameters["epochs"] + 1):

            # Train on the source (MNIST) training loader only
            loss = trainer.train_simple_classifier(model, epoch, hyperparameters,
                                                   trainloader_src, criterion,
                                                   optimizer, device)

            # Evaluate on the target (MNIST-M) test loader
            accuracy = trainer.test_simple_classifier(model, testloader_tgt, device)

            # One record per epoch holding both metrics
            wandb.log({"epoch": epoch, "train_loss": loss, "val_accuracy": accuracy},
                      step=epoch)

        end_time = datetime.now()
        # The measured interval covers the per-epoch evaluation as well as training
        print(f"Train Time for {hyperparameters['epochs']} epochs: {end_time - start_time}")

    return model

In [None]:
# Run training/evaluation with the hyperparameters in `config`; main() returns the model.
output = main(config)