In [1]:
from torch.utils.data import DataLoader
from tqdm import tqdm

import torch

import pytorch_lightning as pl
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from dataset.digit import ChanDup, CENDataset
from pl_module.gcada import LitGCADA

# Seed the global RNGs up front (the captured log reports "Global seed set to 10").
# NOTE(review): GCADAConfig.seed further down is 42 while this call uses 10 —
# confirm which value is the intended one.
pl.seed_everything(10)

Global seed set to 10


10

In [2]:
class GCADAConfig:
    """Static settings used to translate MNIST with a trained GCADA model.

    Every option is a plain class attribute. This notebook itself reads
    ``device``, ``root``, ``root_tgt``, ``batch_size``, ``num_workers`` and
    ``translated_dir``. The instance is also passed as ``config`` to
    ``LitGCADA.load_from_checkpoint``; presumably that is where the remaining
    fields are consumed (TODO confirm against ``pl_module.gcada``).
    """

    # This used to be the bare annotation ``lambda_noise: None``. An annotation
    # without a value never creates the attribute, so reading
    # ``config.lambda_noise`` raised AttributeError. Assign the value instead.
    lambda_noise = None
    # Device string used for the model and for every batch in this notebook.
    device = "cuda:1"

    # Root handed to torchvision's MNIST dataset (source images).
    root = "/shared/lorenzo/mnist-zsda/MNIST_G"
    # Root handed to CENDataset (target-domain split).
    root_tgt = "/shared/lorenzo/mnist-zsda/MNIST_C"

    model_name = "gcada"
    backbone_name = "lenet"

    # Not read in this notebook; presumably training-time options for LitGCADA.
    root_tgt_train = "/shared/lorenzo/mnist-zsda/FashionMNIST_C"
    num_blocks = 9
    hidden_dim_dsc = 64
    lambda_idt = 1.0
    lambda_sem = 0.0
    transformation = "rotate"
    beta1 = 0.5
    pretrained = "/root/dezsda/checkpoints/timm-lenet-epoch=03-task=digit-val_loss=0.0297.ckpt"
    fix_block_up = False
    lambda_sem_idt = 10.0
    sem_idt_per_epoch = 1

    lr = 1e-3
    optimizer = "adam"
    # Batch size and worker count are passed to both DataLoaders below.
    batch_size = 64
    max_epochs = 50
    grad_accum = 1
    es_patience = None
    task = "digit"
    img_size = 28
    fold_no = 0
    num_workers = 8
    channels = 3
    logger = True
    # NOTE(review): the notebook calls pl.seed_everything(10), not this value.
    seed = 42
    project = "csi-har"
    checkpoint_dir = "/root/dezsda/checkpoints"
    gpus = 1
    num_classes = 10
    num_test_sets = 4

    # Source and target roots must share the dataset prefix (the part of the
    # last path component before the first underscore, e.g. "MNIST").
    assert root.split("/")[-1].split("_")[0] == root_tgt.split("/")[-1].split("_")[0], (
        "root and root_tgt must point at the same base dataset"
    )
    # Output directory name: <prefix>_<source suffix><target suffix>,
    # e.g. MNIST_G + MNIST_C -> MNIST_GC.
    translated_dir = (
        "/shared/lorenzo/mnist-zsda/"
        f"{root.split('/')[-1].split('_')[0]}"
        f"_{root.split('/')[-1].split('_')[-1]}"
        f"{root_tgt.split('/')[-1].split('_')[-1]}"
    )


config = GCADAConfig()

In [3]:
# Restore the Lightning module from its checkpoint and pull out the wrapped
# network, then call eval() on it before running inference.
checkpoint_path = "./checkpoints/gcada-lenet-epoch=00-task=digit-val_loss=0.0000.ckpt"
pl_model = LitGCADA.load_from_checkpoint(checkpoint_path, config=config)
model = pl_model.model
model.eval()

# Preprocessing applied to every image, in order.
# (transforms.Resize(32) is currently disabled.)
preprocessing_steps = [
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
    ChanDup(),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split, downloaded into config.root when missing.
dataset = datasets.MNIST(
    root=config.root,
    train=True,
    download=True,
    transform=transform,
)

# NOTE(review): drop_last=True discards a final partial batch, so those
# samples never reach the translated set — confirm this is intended.
loader_options = dict(
    batch_size=config.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=config.num_workers,
    drop_last=True,
)
dataloader = DataLoader(dataset, **loader_options)

In [4]:
# Run the source loader through model.translate and gather the outputs.
model = model.to(config.device)

# Stay None when the loader yields nothing, as before.
translated_inputs_train = None
translated_targets_train = None

# Collect per-batch results and concatenate once at the end; concatenating
# inside the loop re-copies the whole accumulator on every step.
train_input_chunks = []
train_target_chunks = []

# Inference only: no_grad keeps autograd from recording a graph, which the
# accumulated tensor would otherwise keep alive through the first batch.
with torch.no_grad():
    for batch in tqdm(dataloader):
        inputs_src, targets_src = batch
        inputs_src = inputs_src.to(config.device)
        targets_src = targets_src.to(config.device)

        outputs = model.translate(inputs_src, random_noise=0.1)

        train_input_chunks.append(outputs.detach())
        train_target_chunks.append(targets_src.detach())

        del outputs
        del targets_src

if train_input_chunks:
    # The previous loop prepended each new batch; reverse the chunk lists so
    # the final sample order (and input/target alignment) is unchanged.
    translated_inputs_train = torch.cat(train_input_chunks[::-1], dim=0)
    translated_targets_train = torch.cat(train_target_chunks[::-1], dim=0)

del train_input_chunks, train_target_chunks

100%|██████████| 937/937 [00:10<00:00, 89.76it/s]


In [5]:
# Rebind `dataset` / `dataloader` to the CENDataset under config.root_tgt
# (train=False), reusing the `transform` pipeline built earlier.
dataset = CENDataset(
    root=config.root_tgt,
    train=False,
    transform=transform,
)
# NOTE(review): drop_last=True discards a final partial batch — confirm
# that dropping those samples from this split is intended.
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

In [6]:
# Run the current loader through model.translate (default arguments) and
# gather the outputs.
model = model.to(config.device)

# Stay None when the loader yields nothing, as before.
translated_inputs_test = None
translated_targets_test = None

# Collect per-batch results and concatenate once at the end; concatenating
# inside the loop re-copies the whole accumulator on every step.
test_input_chunks = []
test_target_chunks = []

# Inference only: no_grad keeps autograd from recording a graph, which the
# accumulated tensor would otherwise keep alive through the first batch.
with torch.no_grad():
    for batch in tqdm(dataloader):
        inputs_src, targets_src = batch
        inputs_src = inputs_src.to(config.device)
        targets_src = targets_src.to(config.device)

        outputs = model.translate(inputs_src)

        test_input_chunks.append(outputs.detach())
        test_target_chunks.append(targets_src.detach())

        del outputs
        del targets_src

if test_input_chunks:
    # The previous loop prepended each new batch; reverse the chunk lists so
    # the final sample order (and input/target alignment) is unchanged.
    translated_inputs_test = torch.cat(test_input_chunks[::-1], dim=0)
    translated_targets_test = torch.cat(test_target_chunks[::-1], dim=0)

del test_input_chunks, test_target_chunks

100%|██████████| 156/156 [00:01<00:00, 88.97it/s] 


In [7]:
# idx = 2
# plt.imshow(translated_inputs_train[idx].permute(1,2,0).cpu().detach())
# print(translated_targets_train[idx])

In [8]:
# Bundle both splits as NumPy arrays: {"train"|"test": {"data", "targets"}}.
data = {
    split_name: {
        "data": split_inputs.detach().cpu().numpy(),
        "targets": split_targets.detach().cpu().numpy(),
    }
    for split_name, split_inputs, split_targets in (
        ("train", translated_inputs_train, translated_targets_train),
        ("test", translated_inputs_test, translated_targets_test),
    )
}

In [9]:
import pickle
import os

# Create the export directory if needed, then pickle the bundle into it.
os.makedirs(config.translated_dir, exist_ok=True)

output_path = os.path.join(config.translated_dir, "data.pkl")
with open(output_path, "wb") as pickle_file:
    pickle.dump(data, pickle_file)

In [10]:
"completed"

'completed'