# Experiments on finding the barycenter with a kernel cost for MNIST digits 0 and 1 (data space)

In [1]:
import warnings
# Silence every warning for the whole notebook session (one global "ignore" rule).
warnings.simplefilter('ignore')

In [2]:
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import itertools
import wandb
import os
import sys

sys.path.append("../../")
from src.utils import Config, weights_init_D, freeze, unfreeze
from src.data import DatasetSampler
from src.models import ResNet_D, UNet
# from src.train import train
from src.train_kernel import train_kernel_data
from src.cost import strong_cost

%load_ext autoreload
%autoreload 2

## 1. Config

In [3]:
BATCH_SIZE = 128
NUM_EPOCHS = 10000
INNER_ITERATIONS = 10
LR = 2e-4
GPU_DEVICE = 0
KREG = 'energy' # 'mse', 'energy', 'gaussian', 'laplacian'
GAMMA = 1.0
Z_STD = 0.1

In [4]:
CONFIG = Config()

# --- Data ---------------------------------------------------------------
CONFIG.CLASSES = [0, 1]          # MNIST digit used for each distribution
CONFIG.K = len(CONFIG.CLASSES)   # number of distributions (one per class)
CONFIG.LAMBDAS = [0.5, 0.5]      # weight of each distribution in the barycenter
CONFIG.IMG_SIZE = 32
CONFIG.NC = 1                    # image channels
CONFIG.ZC = 1                    # extra input channels of the maps (UNet takes NC + ZC)
# Flattened image size, derived so it stays consistent with NC / IMG_SIZE (= 1*32*32).
CONFIG.DIM = CONFIG.NC * CONFIG.IMG_SIZE * CONFIG.IMG_SIZE
CONFIG.DATASET_PATH = '../../data/MNIST'
CONFIG.DATASET = 'mnist'
CONFIG.Z_STD = Z_STD
CONFIG.RESNET_ENCODER_LATENT = 4
# CONFIG.FLAG_LATENT = False
# CONFIG.FLAG_LATENT_CRITIC = False

# Every distribution needs exactly one barycenter weight.
if len(CONFIG.LAMBDAS) != CONFIG.K:
    raise ValueError(
        f"expected {CONFIG.K} barycenter weights (one per class), got {len(CONFIG.LAMBDAS)}"
    )

# --- Kernel regularization -----------------------------------------------
CONFIG.KREG = KREG
CONFIG.GAMMA = GAMMA

# --- Optimization ----------------------------------------------------------
CONFIG.LR = LR
# CONFIG.CLIP_GRADS_NORM = False
CONFIG.BETAS = (0.2, 0.99)  # Adam betas for both optimizers
CONFIG.BATCH_SIZE = BATCH_SIZE
CONFIG.NUM_EPOCHS = NUM_EPOCHS
CONFIG.INNER_ITERATIONS = INNER_ITERATIONS  # could be made smaller

# --- Device ----------------------------------------------------------------
CONFIG.GPU_DEVICE = GPU_DEVICE
# Explicit error instead of `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("This experiment requires a CUDA device, but none is available.")
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

## 2. Data for experiment

In [5]:
# Resize to IMG_SIZE x IMG_SIZE, convert to a [0, 1] tensor, then rescale to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: 2 * x - 1),
])

# One sampler per distribution: sampler k only sees MNIST images whose label
# is CONFIG.CLASSES[k].
data_samplers = []

for k in range(CONFIG.K):
    dataset = torchvision.datasets.MNIST(root=CONFIG.DATASET_PATH,
                                         download=True,
                                         transform=transform)

    # Vectorised boolean mask instead of a per-sample Python loop. Index the
    # existing tensors directly rather than copy-constructing them with
    # torch.tensor(...), which emits a (globally silenced) warning.
    mask = dataset.targets == CONFIG.CLASSES[k]
    dataset.targets = dataset.targets[mask].numpy()  # targets stay a NumPy array, as before
    dataset.data = dataset.data[mask]
    # NOTE(review): batch_size=256 differs from CONFIG.BATCH_SIZE (128); confirm
    # which value DatasetSampler / train_kernel_data actually uses per step.
    data_samplers.append(DatasetSampler(dataset, flag_label=True, num_workers=1, batch_size=256))

## 3. Potentials and maps

In [6]:
def new_f():
    """Build a ResNet_D potential network on CONFIG.DEVICE, initialised with weights_init_D."""
    f = ResNet_D(size=CONFIG.IMG_SIZE,
                 nc=CONFIG.NC,
                 nfilter=64,
                 nfilter_max=512,
                 res_ratio=0.1).to(CONFIG.DEVICE)
    weights_init_D(f)
    return f


# With K == 2 only a single potential network is trained; otherwise there is one
# per distribution. NOTE(review): presumably train_kernel_data derives the second
# potential from the first when K == 2 -- confirm in src.train_kernel.
num_potentials = 1 if CONFIG.K == 2 else CONFIG.K
nets_for_pot = [new_f() for _ in range(num_potentials)]

# One Adam optimizer over the parameters of all potential networks.
nets_for_pot_opt = torch.optim.Adam(
    itertools.chain.from_iterable(net.parameters() for net in nets_for_pot),
    CONFIG.LR, betas=CONFIG.BETAS)
 

In [7]:
# One UNet map per distribution: input has NC + ZC channels, output has NC channels.
maps = []
for _ in range(CONFIG.K):
    unet = UNet(n_channels=CONFIG.NC + CONFIG.ZC,
                n_classes=CONFIG.NC,
                base_factor=48,
                bilinear=True)
    maps.append(unet.to(CONFIG.DEVICE))

# A single Adam optimizer shared by all maps.
param_maps = [net.parameters() for net in maps]
maps_opt = torch.optim.Adam(itertools.chain.from_iterable(param_maps),
                            CONFIG.LR, betas=CONFIG.BETAS)

## 4. Train

In [8]:
WANDB_MODE = 'disabled'  # switch to 'online' to log the run to W&B

In [9]:
# The run name encodes the main hyper-parameters so runs are easy to tell apart.
run_name = (f"mnist01_kernel_test_BS_{CONFIG.BATCH_SIZE}_EP_{CONFIG.NUM_EPOCHS}"
            f"_INN_{CONFIG.INNER_ITERATIONS}_LR_{CONFIG.LR}")
wandb.init(project="BNOT", name=run_name, config=CONFIG,
           mode=WANDB_MODE, reinit=True)



In [11]:
# Run the training loop: potentials and maps with their optimizers, one
# sampler per distribution, and the experiment config.
train_kernel_data(
    nets_for_pot,
    maps,
    nets_for_pot_opt,
    maps_opt,
    data_samplers,
    CONFIG,
)

In [12]:
# End the W&B run started with wandb.init earlier in the notebook.
wandb.finish()