In [None]:
# Pick up edits to the local modules (UNet, csprites, utils, optimizer) live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the parent directory importable so the project-local packages (UNet, csprites, ...) resolve.
sys.path.extend([".."])

In [None]:
# Python
from pathlib import Path
import os
import warnings
import math
import datetime
import time
import pickle
warnings.filterwarnings('ignore')

# Extern
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from dotted_dict import DottedDict
import matplotlib.pyplot as plt
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
import torchvision.transforms as transforms
import pprint

# Local
from UNet.unet import UNet
from csprites.datasets import SegmentationDataset
import utils
from optimizer import get_optimizer

In [None]:
config = {
    'device': 'cuda',
    'cuda_visible_devices': '1',
    'p_data': "/mnt/data/csprites/single_csprites_64x64_n7_c128_a32_p10_s3_bg_inf_random_function_100000",
    'target_variable': 'shape',
    'batch_size': 512,
    'num_workers': 6,
    'num_epochs': 10,
    'seed': 0,          # RNG seed for torch + numpy, applied at the end of this cell
    'freqs': {
        'ckpt': 50,     # epochs; NOTE(review): larger than num_epochs and not read by the training loop
        'eval': 1,      # epochs between validation passes
        'show': 1,      # epochs between prediction visualizations
    },
    'n_vis': 16,
    'model': {
        'chs_tail': [3, 8],
        'chs_down': [8, 16, 32, 64, 128, 256],
        'chs_up': [256, 128, 64, 32, 16, 8],
        'chs_head': [8, 1],
        'n_conv_blocks': 1
    },
    'optimizer': 'adam',
    'optimizer_args': {
        'lr': 1e-3,
        'weight_decay': 1e-6
    },
    'p_ckpts': "ckpts",
    'p_model': "model_{}.ckpt",
    'p_stats': "stats.pkl",
    'p_config': 'config.pkl',
    'p_R_train': 'R_train.npy',
    'p_R_valid': 'R_valid.npy',
    'p_Y_valid': 'Y_valid.npy',
    'p_Y_train': 'Y_train.npy',
    'linprob': {
        'optimizer': 'adam',
        'optimizer_args': {
            'lr': 0.001,
            'weight_decay': 1e-6
        },
        'n_hid': 0,
        'd_hid': 1024,
        'num_epochs': 1
    }
}
# Experiment directory: <experiments root>/<dataset name>/tmp/UNet_sup_<timestamp>
p_base = Path("/mnt/experiments/csprites") / Path(config["p_data"]).name / "tmp"
ts = time.time()
st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')
#
config["p_experiment"] = str(p_base / "UNet_sup_{}".format(st))

# Seed the RNGs up front so weight init, DataLoader shuffling and random test inputs repeat
# across runs. NOTE: cudnn.benchmark=True still permits nondeterministic GPU kernels.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

config = DottedDict(config)
pprint.pprint(config)

In [None]:
# TORCH SETTINGS
# Restrict the visible GPUs first. CUDA reads this variable only when it is initialized, so it must
# be set before any CUDA call (including is_available() below). Re-running this cell after CUDA
# has been initialized has no effect on device visibility.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices
# Let cuDNN benchmark conv algorithms (inputs are fixed-size here); faster, but not bit-deterministic.
torch.backends.cudnn.benchmark = True
# Fall back to CPU when CUDA is requested but no GPU is visible, instead of failing later in .to(device).
if config.device.startswith("cuda") and not torch.cuda.is_available():
    print("CUDA requested but not available; falling back to CPU.")
    device = torch.device("cpu")
else:
    device = torch.device(config.device)

In [None]:
# Dataset-level metadata; the per-channel "means"/"stds" drive the normalization transforms.
# NOTE(review): unpickling can execute arbitrary code -- only point p_data at trusted datasets.
p_ds_config = Path(config.p_data) / "config.pkl"
ds_config = pickle.loads(p_ds_config.read_bytes())

In [None]:
# Input images: tensor conversion followed by per-channel normalization with the dataset stats.
_means, _stds = ds_config["means"], ds_config["stds"]
norm_transform = utils.normalize_transform(_means, _stds)
inverse_norm_transform = utils.inverse_normalize_transform(_means, _stds)

transform_train = transforms.Compose([transforms.ToTensor(), norm_transform])


def _binarize_mask(mask):
    """Map every strictly positive pixel to 1.0 and every other pixel to 0.0."""
    return (mask > 0).float()


# Segmentation targets: tensor conversion, then binarization (foreground = any pixel > 0).
transform_segm = transforms.Compose([transforms.ToTensor(), _binarize_mask])

In [None]:
# TRAIN: shuffled; drop_last keeps every training batch the same size.
ds_train = SegmentationDataset(
    p_data=config.p_data,
    transform=transform_train,
    target_transform=None,
    seg_transform=transform_segm,
    split="train"
)
dl_train = DataLoader(
    ds_train,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
    drop_last=True
)
# VALID: same non-augmenting transform as training (ToTensor + normalize).
# Evaluation must see every sample exactly once, so the loader neither shuffles nor drops the last
# batch. drop_last=True silently skipped up to batch_size-1 samples, and a split smaller than
# batch_size produced zero batches. Without shuffling, the preview batch taken from this loader
# is also the same on every run.
ds_valid = SegmentationDataset(
    p_data=config.p_data,
    transform=transform_train,
    target_transform=None,
    seg_transform=transform_segm,
    split="valid"
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=False,
    drop_last=False
)

# Visualize a training batch

In [None]:
n_vis = 3 ** 2  # preview a 3x3 grid of training samples (rebound to config.n_vis in the pipeline test)

In [None]:
# One training batch: images x, y (presumably the target_variable label; unused here), masks z.
batch = next(iter(dl_train))
x, y, z = batch

In [None]:
# Keep only the first n_vis samples for display.
x, z = x[:n_vis], z[:n_vis]

In [None]:
# Sanity-check the masks: after transform_segm they should be float tensors holding only 0s and 1s.
mask_summary = (z.shape, z.dtype, z.min(), z.max())
print(*mask_summary)

In [None]:
# Undo the input normalization so the images display with their original colors.
# NOTE(review): this rebinds x, so re-running this cell on its own applies the inverse transform
# twice -- re-run the batch-loading and slicing cells first.
x = inverse_norm_transform(x)

In [None]:
# Tile the preview images into a square-ish grid (sqrt(n_vis) images per row) and show it.
side = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(x, nrow=side)
fig, ax = plt.subplots()
ax.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# Tile the binary masks the same way; make_grid expands single-channel input to 3 channels.
mask_nrow = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(tensor=z, nrow=mask_nrow)

In [None]:
# CHW -> HWC for matplotlib.
fig, ax = plt.subplots()
ax.imshow(grid_img.permute(1, 2, 0))
plt.show()

# Model

In [None]:
model = UNet(
    config.model.chs_tail,
    config.model.chs_down,
    config.model.chs_up,
    config.model.chs_head,
    config.model.n_conv_blocks
)
# Compare the device *type*: a torch.device never compares equal to a plain string, so a
# `device != "cpu"` check is always True and a CPU run would be reported (and wrapped) as GPU.
on_gpu = device.type != "cpu"
if on_gpu and torch.cuda.device_count() > 1:
    print("Using {} gpus!".format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    # Re-expose `backbone` on the wrapper only when the wrapped model has one.
    # NOTE(review): UNet may not define `backbone`; unconditional access raised AttributeError.
    if hasattr(model.module, "backbone"):
        model.backbone = model.module.backbone
elif on_gpu:
    print("Using 1 GPU!")
else:
    print("Using CPU!")
n_params = utils.count_parameters(model)
model = model.to(device)
print("#Params:", n_params)

In [None]:
# Pixel-wise MSE between predicted and binary target masks.
# NOTE(review): for 0/1 masks a BCE-style loss is the more usual choice -- confirm MSE is intended.
criterion = nn.MSELoss()
optimizer = get_optimizer(
    config.optimizer,
    model.parameters(),
    config.optimizer_args,
)

# Test pipeline (forward pass of the untrained model)

In [None]:
# Push a fixed preview batch through the still-untrained model to check the pipeline end to end.
# The same x_vis / z_ori_vis batch is re-visualized during training every `freqs.show` epochs.
model.eval()
n_vis = config.n_vis
x_vis, _, z_ori_vis = next(iter(dl_valid))
#
x_vis = x_vis[:n_vis]
z_ori_vis = z_ori_vis[:n_vis]
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    z_pre_vis = model(x_vis.to(device)).cpu()

In [None]:
# Shapes/dtypes of the preview inputs, ground-truth masks and predicted masks, in that order.
for _tensor in (x_vis, z_ori_vis, z_pre_vis):
    print(_tensor.shape, _tensor.dtype)

In [None]:
def visualize(x, z_pre, z_ori, n_vis=4, nrow=None):
    """Show the input images, then ground-truth masks stacked above predicted masks.

    Args:
        x: normalized input batch; un-normalized with the notebook's ``inverse_norm_transform``.
        z_pre: predicted masks, in the same sample order as ``x``.
        z_ori: ground-truth masks, in the same sample order as ``x``.
        n_vis: number of leading samples to display.
        nrow: images per grid row. Defaults to ``n_vis``, so each grid is a single row and every
            ground-truth mask sits directly above its prediction.
    """
    nrow = n_vis if nrow is None else nrow
    # Slice and move to CPU before un-normalizing so the transform never receives GPU tensors.
    x = inverse_norm_transform(x.detach().cpu()[:n_vis])
    # Predictions may fall outside [0, 1] (MSE-trained, head activation unknown); clamp for display.
    z_pre = z_pre.detach().cpu()[:n_vis].clamp(0, 1)
    z_ori = z_ori.detach().cpu()[:n_vis]
    # make_grid takes `nrow` (images per row). It has no `ncol` argument: depending on the
    # torchvision version, `ncol` was silently ignored or raised TypeError.
    img_x = torchvision.utils.make_grid(x, nrow=nrow).permute(1, 2, 0)
    img_z_ori = torchvision.utils.make_grid(z_ori, nrow=nrow).permute(1, 2, 0)
    img_z_pre = torchvision.utils.make_grid(z_pre, nrow=nrow).permute(1, 2, 0)
    # Concatenate along height (dim 0 after the HWC permute): ground truth on top, prediction below.
    img_z = torch.cat([img_z_ori, img_z_pre], dim=0)

    plt.imshow(img_x)
    plt.show()

    plt.imshow(img_z)
    plt.show()

In [None]:
visualize(x=x_vis, z_pre=z_pre_vis, z_ori=z_ori_vis, n_vis=n_vis)

# Train

In [None]:
# Per-split history: mean loss per recorded epoch, and the epoch index it was recorded at.
stats = DottedDict({
    split: {'loss': [], 'epoch': []}
    for split in ('train', 'valid')
})

In [None]:
global_step = 0
bar_tail = '{bar:10}{r_bar}{bar:-10b}'
for epoch_idx in range(1, config.num_epochs + 1):
    # ---- train ----
    model.train()
    epoch_step = 0
    epoch_loss = 0.0
    desc = "Epoch [{:3}/{:3}] {}:".format(epoch_idx, config.num_epochs, 'train')
    pbar = tqdm(dl_train, bar_format=desc + bar_tail)
    for x, _, z in pbar:
        x = x.to(device)
        z = z.to(device)
        # Drop all parameter gradients (equivalent to setting each param.grad to None).
        model.zero_grad(set_to_none=True)
        z_pred = model(x)
        loss = criterion(z_pred, z)
        loss.backward()
        optimizer.step()
        #
        epoch_loss += loss.item()
        epoch_step += 1
        global_step += 1
        #
        pbar.set_postfix({"loss": loss.item()})
    # Mean over batches (equal-sized, since dl_train drops the last partial batch); NaN if none ran.
    stats.train.loss.append(epoch_loss / epoch_step if epoch_step else float("nan"))
    stats.train.epoch.append(epoch_idx)

    # ---- validate ----
    if epoch_idx % config.freqs.eval == 0:
        model.eval()
        loss_sum = 0.0
        n_samples = 0
        desc = "Epoch [{:3}/{:3}] {}:".format(epoch_idx, config.num_epochs, 'valid')
        pbar = tqdm(dl_valid, bar_format=desc + bar_tail)
        with torch.no_grad():
            for x, _, z in pbar:
                x = x.to(device)
                z = z.to(device)
                z_pred = model(x)
                loss = criterion(z_pred, z)
                # criterion averages within a batch; weight by batch size so a short final
                # batch does not skew the split-level mean.
                n_batch = x.size(0)
                loss_sum += loss.item() * n_batch
                n_samples += n_batch
                #
                pbar.set_postfix({"loss": loss.item()})
        valid_loss = loss_sum / n_samples if n_samples else float("nan")
        print("   Loss: {:.4f}".format(valid_loss))
        stats.valid.epoch.append(epoch_idx)
        stats.valid.loss.append(valid_loss)

    # ---- visualize the fixed preview batch ----
    if epoch_idx % config.freqs.show == 0:
        model.eval()
        with torch.no_grad():
            z_pre_vis = model(x_vis.to(device)).cpu()
        visualize(x_vis, z_pre_vis, z_ori_vis, n_vis)

# Plot

In [None]:
# plot losses
plt.plot(stats.train.epoch, stats.train.loss, label="train")
plt.plot(stats.valid.epoch, stats.valid.loss, label="valid")
plt.legend()
plt.show()

# Test: does the output depend on the input?

In [None]:
# Probe whether the model's output depends on its input: feed two independent uniform-noise
# batches and compare the predicted masks in the following cells.
# NOTE(review): the model was trained on normalized images, so noise in [0, 1) is off-distribution.
model.eval()
noise_shape = (128, 3, 64, 64)
x2 = torch.rand(noise_shape)  # x2 is drawn before x1; keep this order for reproducible seeded values
x1 = torch.rand(noise_shape)
with torch.no_grad():
    z1, z2 = (model(inp.to(device)).cpu() for inp in (x1, x2))

In [None]:
# MSE between the two noise-batch predictions; a value near 0 would mean the output ignores the input.
criterion(input=z1, target=z2)

In [None]:
# The same mean squared difference computed by hand; should match the criterion value above.
(z1 - z2).pow(2).mean()