In [1]:
import pytorch_lightning as pl
pl.seed_everything(10)

from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch
import torchvision
import torchmetrics

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn.functional as F

from dataset.digit import ChanDup, CENDataset
from pl_module.zsgcada import LitZSGCADA

Global seed set to 10


In [2]:
class ZSGCADAConfig:
    checkpoint_filename = [
        # FasionMNIST
        # "zsgcada-lenet-epoch=02-task=digit-val_loss_tgt=2.1088.ckpt",
        # "zsgcada-lenet-epoch=08-task=digit-val_loss_tgt=2.0557.ckpt",
        # "zsgcada-lenet-epoch=04-task=digit-val_loss_tgt=1.8896.ckpt",
        # "zsgcada-lenet-epoch=06-task=digit-val_loss_tgt=1.8589.ckpt",

        "zsgcada-lenet-epoch=00-task=digit-val_loss_tgt=2.6313.ckpt",
        "zsgcada-lenet-epoch=00-task=digit-val_loss_tgt=2.2142.ckpt",
        "zsgcada-lenet-epoch=00-task=digit-val_loss_tgt=2.0466.ckpt",
        "zsgcada-lenet-epoch=00-task=digit-val_loss_tgt=1.7286.ckpt",
    ]
    lambda_noise: None
    device="cuda:1"
    
    root="/shared/lorenzo/mnist-zsda/MNIST_G"
    root_tgt="/shared/lorenzo/mnist-zsda/MNIST_C"
    
    model_name="zsgcada"
    backbone_name="lenet"
    
    root_tgt_train="/shared/lorenzo/mnist-zsda/FashionMNIST_C"
    num_blocks=9
    hidden_dim_dsc=32
    lambda_idt=1.0
    lambda_sem=0.0
    transformation="rotate"
    beta1=0.5
    pretrained="/root/dezsda/checkpoints/timm-lenet-epoch=03-task=digit-val_loss=0.0297.ckpt"
    fix_block_up=False
    lambda_sem_idt=10.0
    sem_idt_per_epoch=1
    
    lr_enc_multi: 1.0
    lr_dsc_multi: 1.0
    lambda_rec: 1.0
    lambda_cross: 1.0

    lr = 1e-3
    optimizer = "adam"
    batch_size = 64
    max_epochs = 50
    grad_accum = 1
    es_patience = None
    task="digit"
    img_size = 28
    fold_no = 0
    num_workers = 8
    channels = 3
    logger = True
    seed = 42
    project = "csi-har"
    checkpoint_dir = "/root/dezsda/checkpoints"
    gpus = 1
    num_classes=10
    num_test_sets=4

    assert root.split("/")[-1].split("_")[0] == root_tgt.split("/")[-1].split("_")[0]
    translated_dir=f"/shared/lorenzo/mnist-zsda/{root.split('/')[-1].split('_')[0]}_{root.split('/')[-1].split('_')[-1]}{root_tgt.split('/')[-1].split('_')[-1]}"
    
config = ZSGCADAConfig()

In [3]:
# cf = config.checkpoint_filename[0]


In [4]:
transform = transforms.Compose([
    # transforms.Resize(32),
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
    ChanDup(),
])

dataset_mnist_c = CENDataset(root="/shared/lorenzo/mnist-zsda/MNIST_E", train=False, transform=transform)
dataloader = DataLoader(dataset_mnist_c, batch_size=config.batch_size,
                        shuffle=True, pin_memory=True, 
                        num_workers=config.num_workers, drop_last=True)

In [5]:
lambda_pp = 0.9
sos = []
posts = []
for cf in config.checkpoint_filename:
    pl_model = LitZSGCADA.load_from_checkpoint(f"/shared/lorenzo/checkpoints/{cf}", config=config)
    model = pl_model.model
    model.eval()
    model.to(config.device)
    

    test_acc_so = torchmetrics.Accuracy()
    test_acc_so = test_acc_so.to(config.device)

    test_acc_translated_only = torchmetrics.Accuracy()
    test_acc_translated_only = test_acc_translated_only.to(config.device)

    test_acc_post = torchmetrics.Accuracy()
    test_acc_post = test_acc_post.to(config.device)

    for batch in tqdm(dataloader):
        inputs_mnist_c, targets = batch

        inputs_mnist_c = inputs_mnist_c.to(config.device)
        targets = targets.to(config.device)
        translated = model.translate(inputs_mnist_c).detach()

        # plt.imshow(inputs_mnist_c.squeeze(0).permute(1,2,0))
        # plt.show()

        # plt.imshow(translated.squeeze(0).permute(1,2,0))
        # plt.show()

        outputs_mnist_c = model.pretrained(inputs_mnist_c)
        outputs_translated = model.pretrained(translated)

        combined_outputs = []
        for output_mnist_c, output_translated in zip(outputs_mnist_c, outputs_translated):
            logit_diff = F.l1_loss(
                output_mnist_c,
                output_translated,
            )

            # if logit_diff < 2.0:
            #     combined_outputs.append(output_mnist_c)
            # else:
            #     combined_outputs.append(output_translated + output_mnist_c)
            combined_outputs.append((1-lambda_pp) * torch.sigmoid(output_translated) + lambda_pp * torch.sigmoid(output_mnist_c))
            # combined_outputs.append(torch.sigmoid(output_mnist_c))
        combined_outputs = torch.stack(combined_outputs, dim=0)
        test_acc_so(outputs_mnist_c.argmax(dim=-1), targets)
        test_acc_translated_only(outputs_translated.argmax(dim=-1), targets)
        test_acc_post(combined_outputs.argmax(dim=-1), targets)
    so = test_acc_so.compute()
    post = test_acc_post.compute()
    print(so, post)
    sos.append(so)
    posts.append(post)
sum(sos) / len(sos), sum(posts) / len(posts)

100%|██████████| 156/156 [00:03<00:00, 50.40it/s]


tensor(0.5440, device='cuda:1') tensor(0.5342, device='cuda:1')


100%|██████████| 156/156 [00:02<00:00, 55.82it/s]


tensor(0.4989, device='cuda:1') tensor(0.5032, device='cuda:1')


100%|██████████| 156/156 [00:02<00:00, 55.94it/s]


tensor(0.5216, device='cuda:1') tensor(0.5277, device='cuda:1')


100%|██████████| 156/156 [00:02<00:00, 55.55it/s]

tensor(0.5162, device='cuda:1') tensor(0.5279, device='cuda:1')





(tensor(0.5202, device='cuda:1'), tensor(0.5233, device='cuda:1'))

In [6]:
# inputs_mnist_c, _ = dataset_mnist_c[0]