In [1]:
import pytorch_lightning as pl
pl.seed_everything(10)

from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from dataset.digit import ChanDup, CENDataset
from pl_module.gcada import LitGCADA

Global seed set to 10


In [2]:
class GCADAConfig:
    checkpoint_filename = [
        "gcada-lenet-epoch=00-task=digit-val_loss_tgt=3.9711.ckpt",
        # "gcada-lenet-epoch=02-task=digit-val_loss=0.0000.ckpt",
        # "gcada-lenet-epoch=04-task=digit-val_loss=0.0000.ckpt",
        # "gcada-lenet-epoch=06-task=digit-val_loss=0.0000.ckpt",
    ]
    lambda_noise: None
    device="cuda:1"
    
    root="/shared/lorenzo/mnist-zsda/MNIST_G"
    root_tgt="/shared/lorenzo/mnist-zsda/MNIST_C"
    
    model_name="gcada"
    backbone_name="lenet"
    
    root_tgt_train="/shared/lorenzo/mnist-zsda/FashionMNIST_C"
    num_blocks=9
    hidden_dim_dsc=64
    lambda_idt=1.0
    lambda_sem=0.0
    transformation="rotate"
    beta1=0.5
    pretrained="/root/dezsda/checkpoints/timm-lenet-epoch=03-task=digit-val_loss=0.0297.ckpt"
    fix_block_up=False
    lambda_sem_idt=10.0
    sem_idt_per_epoch=1

    lr = 1e-3
    optimizer = "adam"
    batch_size = 64
    max_epochs = 50
    grad_accum = 1
    es_patience = None
    task="digit"
    img_size = 28
    fold_no = 0
    num_workers = 8
    channels = 3
    logger = True
    seed = 42
    project = "csi-har"
    checkpoint_dir = "/root/dezsda/checkpoints"
    gpus = 1
    num_classes=10
    num_test_sets=4

    assert root.split("/")[-1].split("_")[0] == root_tgt.split("/")[-1].split("_")[0]
    translated_dir=f"/shared/lorenzo/mnist-zsda/{root.split('/')[-1].split('_')[0]}_{root.split('/')[-1].split('_')[-1]}{root_tgt.split('/')[-1].split('_')[-1]}"
    
config = GCADAConfig()

In [3]:
for cf in config.checkpoint_filename:
    pl_model = LitGCADA.load_from_checkpoint(f"./checkpoints/{cf}", config=config)
    model = pl_model.model
    model.eval()

    transform = transforms.Compose([
        # transforms.Resize(32),
        transforms.CenterCrop((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
        ChanDup(),
    ])

    dataset = CENDataset(root=config.root_tgt_train, train=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size,
                            shuffle=True, pin_memory=True, 
                            num_workers=config.num_workers, drop_last=True)

    model = model.to(config.device)
    translated_inputs_tgt = None
    translated_targets_tgt = None
    for batch in tqdm(dataloader):
        inputs_src, targets_src = batch
        inputs_src = inputs_src.to(config.device)
        targets_src = targets_src.to(config.device)

        outputs = model.get_feature(inputs_src)

        if translated_inputs_tgt is not None:
            translated_inputs_tgt = torch.cat([outputs.detach(), translated_inputs_tgt], dim=0)
            translated_targets_tgt = torch.cat([targets_src.detach(), translated_targets_tgt], dim=0)
        else:
            translated_inputs_tgt = outputs
            translated_targets_tgt = targets_src

        del outputs
        del targets_src

    dataset = datasets.MNIST(root=config.root, train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size,
                            shuffle=True, pin_memory=True, 
                            num_workers=config.num_workers, drop_last=True)

    model = model.to(config.device)
    translated_inputs_src = None
    translated_targets_src = None
    for batch in tqdm(dataloader):
        inputs_src, targets_src = batch
        inputs_src = inputs_src.to(config.device)
        targets_src = targets_src.to(config.device)

        outputs = model.get_feature(inputs_src)

        if translated_inputs_src is not None:
            translated_inputs_src = torch.cat([outputs.detach(), translated_inputs_src], dim=0)
            translated_targets_src = torch.cat([targets_src.detach(), translated_targets_src], dim=0)
        else:
            translated_inputs_src = outputs
            translated_targets_src = targets_src

        del outputs
        del targets_src

    dataset = CENDataset(root=config.root_tgt, train=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=config.batch_size,
                            shuffle=True, pin_memory=True, 
                            num_workers=config.num_workers, drop_last=True)

    model = model.to(config.device)
    translated_inputs_tgt_test = None
    translated_targets_tgt_test = None
    for batch in tqdm(dataloader):
        inputs_src, targets_src = batch
        inputs_src = inputs_src.to(config.device)
        targets_src = targets_src.to(config.device)

        outputs = model.get_feature(inputs_src)

        if translated_inputs_tgt_test is not None:
            translated_inputs_tgt_test = torch.cat([outputs.detach(), translated_inputs_tgt_test], dim=0)
            translated_targets_tgt_test = torch.cat([targets_src.detach(), translated_targets_tgt_test], dim=0)
        else:
            translated_inputs_tgt_test = outputs
            translated_targets_tgt_test = targets_src

        del outputs
        del targets_src

    X_tgt = translated_inputs_tgt.cpu().detach().numpy()[:1000]
    X_src = translated_inputs_src.cpu().detach().numpy()[:1000]
    X_tgt_test = translated_inputs_tgt_test.cpu().detach().numpy()[:1000]

    _ = translated_targets_tgt.cpu().detach().numpy()[:1000].astype(str).tolist()
    y_src = translated_targets_src.cpu().detach().numpy()[:1000].astype(str).tolist()
    y_tgt_test = translated_targets_tgt_test.cpu().detach().numpy()[:1000].astype(str).tolist()

    X = np.concatenate([X_tgt, X_src, X_tgt_test], axis=0)
    X_embedded = TSNE(n_components=2, learning_rate='auto',
                    init='random').fit_transform(X)

    df = pd.DataFrame(X_embedded)
    df["group"] = ["tgt_train (FashionMNIST_C)"] * 1000 + ["src_train (MNIST_G)"] * 1000 +["tgt_test (MNIST_C)"] * 1000
    df["target"] = ["-"] * 1000 + y_src + y_tgt_test

    plt.rcParams["figure.figsize"] = (20,20)
    sns.scatterplot(data=df, x=0, y=1, hue="group", s=400)
    plt.legend(markerscale=3, fontsize=18)
    plt.savefig(f"{cf}.group.png")
    plt.clf()
    

    plt.rcParams["figure.figsize"] = (20,20)
    sns.scatterplot(data=df, x=0, y=1, hue="target", s=400, style="group")
    plt.legend(markerscale=3, fontsize=18)
    plt.savefig(f"{cf}.class.png")
    plt.clf()

100%|██████████| 937/937 [00:07<00:00, 118.94it/s]
100%|██████████| 937/937 [00:07<00:00, 126.09it/s]
100%|██████████| 937/937 [00:07<00:00, 124.63it/s]


<Figure size 1440x1440 with 0 Axes>