In [1]:
import glob
import os
import time
import csv

# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 
# os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

from stargan2.solver_multi_GPU import Solver

import numpy as np

import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from torch.utils.data import DataLoader, TensorDataset

import torchvision.models.resnet as resnet

from torch.backends import cudnn
cudnn.benchmark = True

In [2]:
import torchvision.datasets as datasets
from itertools import groupby
import random
from torch.utils.data.sampler import WeightedRandomSampler

class BinRefDataset(torch.utils.data.Dataset):
    def __init__(self, X0, X1):
        super().__init__()
        
        self.X = torch.from_numpy(np.concatenate([X0, X1]))
        
        x0_l = len(X0)
        x1_l = len(X1)

        self.y = torch.tensor([0]*x0_l + [1]*x1_l)
        
        idx1 = list(range(x0_l))
        idx2 = random.sample(idx1, x0_l)
        self.idx = list(zip(idx1, idx2))
        
        idx1 = range(x0_l, x0_l+x1_l)
        idx2 = random.sample(idx1, x1_l)
        self.idx+= list(zip(idx1, idx2))
        
    def __getitem__(self, index):
        idx1, idx2 = self.idx[index]
        return self.X[idx1], self.X[idx2], self.y[idx1]
    
    def __len__(self):
        return len(self.X)

# def make_balanced_sampler(labels):
#     class_counts = np.bincount(labels)
#     class_weights = 1. / class_counts
#     weights = class_weights[labels]
#     WeightedRandomSampler(weights, len(weights))

In [3]:
batch_size = 48
num_workers = 2

X0_train = (np.load('X0_train_clean_48.npy') - .5) / .5
X1_train = (np.load('X1_train_clean_48.npy') - .5) / .5

X_train = np.concatenate([X0_train, X1_train])
y_train = torch.tensor([0]*len(X0_train)+[1]*len(X1_train))

trainloader = DataLoader(TensorDataset(torch.from_numpy(X_train),y_train), batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)

ds_ref = BinRefDataset(X0_train, X1_train)

loader_ref = DataLoader(dataset=ds_ref,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)



In [4]:
class ResNet(resnet.ResNet):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self._forward_impl(x))

In [5]:
X0_val = (np.load('X0_val_clean_48.npy') - .5) / .5
X1_val = (np.load('X1_val_clean_48.npy') - .5) / .5


X_val = torch.from_numpy(np.concatenate([X0_val, X1_val]))
y_val = torch.from_numpy(np.concatenate([np.load('y0_val_clean_48.npy'), np.load('y1_val_clean_48.npy')]))
# define target labels for each style 
y_s_val = torch.tensor([1]*len(X0_val)+[0]*len(X1_val))

valloader = DataLoader(TensorDataset(X_val, y_val, y_s_val), batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)

Clf = ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=1)
Clf.load_state_dict(torch.load('results/clf_resnet18_48/best_model.pth'));

In [6]:
solver = Solver('StarGAN2_48', 'mse_ds0', 48, n_domains=2, lambda_ds=0)

In [None]:
solver.train(100000, trainloader, loader_ref, Clf, valloader)

[34m[1mwandb[0m: Currently logged in as: [33marray[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.18 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




In [5]:
from torchsummary import summary
from stargan2.model import Discriminator, Generator, MappingNetwork, StyleEncoder

In [2]:
summary(Generator(48, 64).cuda(), [(3, 48, 48), ((64,))], batch_size=48)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [48, 341, 48, 48]           9,548
            Conv2d-2          [48, 512, 48, 48]         174,592
    InstanceNorm2d-3          [48, 341, 48, 48]             682
         LeakyReLU-4          [48, 341, 48, 48]               0
         LeakyReLU-5          [48, 341, 48, 48]               0
         LeakyReLU-6          [48, 341, 48, 48]               0
            Conv2d-7          [48, 341, 48, 48]       1,046,870
    InstanceNorm2d-8          [48, 341, 24, 24]             682
         LeakyReLU-9          [48, 341, 24, 24]               0
        LeakyReLU-10          [48, 341, 24, 24]               0
        LeakyReLU-11          [48, 341, 24, 24]               0
           Conv2d-12          [48, 512, 24, 24]       1,571,840
           ResBlk-13          [48, 512, 24, 24]               0
   InstanceNorm2d-14          [48, 512,

In [2]:
summary(Discriminator(48, 2).cuda(), [(3, 48, 48), ((1,))], batch_size=48)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [48, 341, 48, 48]           9,548
            Conv2d-2          [48, 512, 48, 48]         174,592
         LeakyReLU-3          [48, 341, 48, 48]               0
         LeakyReLU-4          [48, 341, 48, 48]               0
         LeakyReLU-5          [48, 341, 48, 48]               0
            Conv2d-6          [48, 341, 48, 48]       1,046,870
         LeakyReLU-7          [48, 341, 24, 24]               0
         LeakyReLU-8          [48, 341, 24, 24]               0
         LeakyReLU-9          [48, 341, 24, 24]               0
           Conv2d-10          [48, 512, 24, 24]       1,571,840
           ResBlk-11          [48, 512, 24, 24]               0
        LeakyReLU-12          [48, 512, 24, 24]               0
        LeakyReLU-13          [48, 512, 24, 24]               0
        LeakyReLU-14          [48, 512,

In [6]:
summary(MappingNetwork(48).cuda(), [(3, 48, 48), ((1,))], batch_size=48)

tensor([0, 1], device='cuda:0') tensor([0, 0], device='cuda:0')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [48, 3, 48, 512]          25,088
              ReLU-2           [48, 3, 48, 512]               0
            Linear-3           [48, 3, 48, 512]         262,656
              ReLU-4           [48, 3, 48, 512]               0
            Linear-5           [48, 3, 48, 512]         262,656
              ReLU-6           [48, 3, 48, 512]               0
            Linear-7           [48, 3, 48, 512]         262,656
              ReLU-8           [48, 3, 48, 512]               0
            Linear-9           [48, 3, 48, 512]         262,656
             ReLU-10           [48, 3, 48, 512]               0
           Linear-11           [48, 3, 48, 512]         262,656
             ReLU-12           [48, 3, 48, 512]               0
           Linear-13           [48, 3, 

In [7]:
summary(StyleEncoder(48).cuda(), [(3, 48, 48), ((1,))], batch_size=48)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [48, 341, 48, 48]           9,548
            Conv2d-2          [48, 512, 48, 48]         174,592
         LeakyReLU-3          [48, 341, 48, 48]               0
         LeakyReLU-4          [48, 341, 48, 48]               0
         LeakyReLU-5          [48, 341, 48, 48]               0
            Conv2d-6          [48, 341, 48, 48]       1,046,870
         LeakyReLU-7          [48, 341, 24, 24]               0
         LeakyReLU-8          [48, 341, 24, 24]               0
         LeakyReLU-9          [48, 341, 24, 24]               0
           Conv2d-10          [48, 512, 24, 24]       1,571,840
           ResBlk-11          [48, 512, 24, 24]               0
        LeakyReLU-12          [48, 512, 24, 24]               0
        LeakyReLU-13          [48, 512, 24, 24]               0
        LeakyReLU-14          [48, 512,