In [1]:
import glob
import os
import time
import csv

from stargan2.solver_multi_GPU import Solver

import numpy as np

import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from torch.utils.data import DataLoader, TensorDataset

import torchvision.models.resnet as resnet

from torch.backends import cudnn
# Enable cuDNN benchmark mode for the whole session (see the
# torch.backends.cudnn documentation for what this flag controls).
cudnn.benchmark = True

In [2]:
import torchvision.datasets as datasets
from itertools import groupby
import random
from torch.utils.data.sampler import WeightedRandomSampler

class ReferenceDataset(datasets.DatasetFolder):
    """Folder dataset that yields pairs of reference images from one domain.

    Every item is ``(ref1, ref2, label)``: two images loaded from the same
    class folder plus the index of that class. The second image of each pair
    is chosen by shuffling the files of the class once, at construction time,
    so the pairing stays fixed for the lifetime of the dataset object.
    """

    def __init__(
            self,
            root,
            transform = None,
            target_transform = None,
            loader = datasets.folder.default_loader,
            is_valid_file = None,
    ):
        """
        Args:
            root: dataset root directory, forwarded to ``DatasetFolder``.
            transform: optional callable applied to both loaded images.
            target_transform: optional callable applied to the label.
            loader: callable that loads one sample from its path.
            is_valid_file: optional predicate selecting valid files; when
                omitted, files are filtered by ``IMG_EXTENSIONS`` instead.
        """
        super().__init__(root, loader,
                         datasets.folder.IMG_EXTENSIONS if is_valid_file is None else None,
                         transform=transform,
                         target_transform=target_transform,
                         is_valid_file=is_valid_file)

        # group sample paths by label; unlike itertools.groupby this does not
        # rely on self.samples being ordered by label, so no group can be
        # silently overwritten by a later run of the same label
        grouped = {}
        for path, label in self.samples:
            grouped.setdefault(label, []).append(path)

        # create reference image pairs
        references = []
        targets = []
        for domain, paths in grouped.items():
            # a shuffled copy of the same paths supplies the second reference
            shuffled = random.sample(paths, len(paths))
            # repeat the label once per pair
            labels = [domain] * len(paths)
            targets += labels
            references += list(zip(paths, shuffled, labels))
        # override samples with (path1, path2, label) triples
        self.samples = references
        # override targets so they stay aligned with the new samples
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (ref sample 1, ref sample 2, ref target)
        """
        ref_path1, ref_path2, ref_target = self.samples[index]

        ref1 = self.loader(ref_path1)
        ref2 = self.loader(ref_path2)

        # the same transform object is applied to both images independently
        if self.transform is not None:
            ref1 = self.transform(ref1)
            ref2 = self.transform(ref2)

        if self.target_transform is not None:
            ref_target = self.target_transform(ref_target)

        return ref1, ref2, ref_target

def make_balanced_sampler(labels):
    class_counts = np.bincount(labels)
    class_weights = 1. / class_counts
    weights = class_weights[labels]
    WeightedRandomSampler(weights, len(weights))

In [3]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

img_size = 128
batch_size = 8
num_workers = 4

transform = transforms.Compose([
    transforms.Resize([img_size, img_size]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

ds = ImageFolder('data/mai', transform)
ds_ref = ReferenceDataset('data/mai', transform)

sampler = make_balanced_sampler(ds.targets)
loader = DataLoader(dataset=ds,
                    batch_size=batch_size,
                    sampler=sampler,
                    shuffle=True,
                    num_workers=num_workers,
                    pin_memory=True,
                    drop_last=True)

sampler_ref = make_balanced_sampler(ds_ref.targets)
loader_ref = DataLoader(dataset=ds_ref,
                        batch_size=batch_size,
                        sampler=sampler_ref,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)


In [4]:
# Class folder names and the name -> index mapping. Both are public
# attributes that the dataset filled in when it was constructed, so there is
# no need to call the private _find_classes helper and rescan the directory.
domains, domains_mapping = ds.classes, ds.class_to_idx
n_domains = len(domains)
print('n_domains:', n_domains)
domains_mapping

n_domains: 2


{'wsi1_tiles': 0, 'wsi2_tiles': 1}

In [5]:
# Build the solver with the domain count derived from the dataset folders
# instead of a hard-coded 2, so it cannot drift from the data on disk.
# NOTE(review): the meaning of the two string arguments is defined by
# stargan2.solver_multi_GPU.Solver — confirm there.
solver = Solver('StarGAN2_MAI_128', 'r1_ds1_1', img_size, n_domains=n_domains, lambda_ds=.1)

In [7]:
# Load previously saved models from this results directory into the solver.
# NOTE(review): the directory name ('..._ds1/') differs from the run name
# passed to Solver ('r1_ds1_1') — confirm this is the intended checkpoint.
solver.load_models('results/StarGAN2_MAI_128_ds1/')

In [None]:
# Start training with the image loader and the reference-pair loader; the
# plain ImageFolder dataset is handed over as `val`.
# NOTE(review): the first argument is presumably the number of training
# iterations — confirm against Solver.train.
solver.train(
    300_000,
    loader,
    loader_ref,
    val=ds,
)

[34m[1mwandb[0m: Currently logged in as: [33marray[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.29 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




In [2]:
from torchsummary import summary
from stargan2.model import Discriminator, Generator, MappingNetwork, StyleEncoder

In [2]:
# Print a per-layer summary (output shapes and parameter counts) of the
# generator on the GPU, fed an image-shaped input and a 64-element vector.
# NOTE(review): Generator(128, 64) is presumably (image size, style size) —
# confirm against stargan2.model.
summary(
    Generator(128, 64).cuda(),
    [(3, 128, 128), (64,)],
    batch_size=8,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [8, 128, 128, 128]           3,584
            Conv2d-2         [8, 256, 128, 128]          32,768
    InstanceNorm2d-3         [8, 128, 128, 128]             256
         LeakyReLU-4         [8, 128, 128, 128]               0
         LeakyReLU-5         [8, 128, 128, 128]               0
         LeakyReLU-6         [8, 128, 128, 128]               0
         LeakyReLU-7         [8, 128, 128, 128]               0
         LeakyReLU-8         [8, 128, 128, 128]               0
            Conv2d-9         [8, 128, 128, 128]         147,584
   InstanceNorm2d-10           [8, 128, 64, 64]             256
        LeakyReLU-11           [8, 128, 64, 64]               0
        LeakyReLU-12           [8, 128, 64, 64]               0
        LeakyReLU-13           [8, 128, 64, 64]               0
        LeakyReLU-14           [8, 128,

In [2]:
# Print a per-layer summary (output shapes and parameter counts) of the
# discriminator on the GPU, fed an image-shaped input and a 1-element input.
# NOTE(review): Discriminator(128, 2) is presumably (image size, number of
# domains) — confirm against stargan2.model.
summary(
    Discriminator(128, 2).cuda(),
    [(3, 128, 128), (1,)],
    batch_size=8,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [8, 128, 128, 128]           3,584
            Conv2d-2         [8, 256, 128, 128]          32,768
         LeakyReLU-3         [8, 128, 128, 128]               0
         LeakyReLU-4         [8, 128, 128, 128]               0
         LeakyReLU-5         [8, 128, 128, 128]               0
         LeakyReLU-6         [8, 128, 128, 128]               0
         LeakyReLU-7         [8, 128, 128, 128]               0
            Conv2d-8         [8, 128, 128, 128]         147,584
         LeakyReLU-9           [8, 128, 64, 64]               0
        LeakyReLU-10           [8, 128, 64, 64]               0
        LeakyReLU-11           [8, 128, 64, 64]               0
        LeakyReLU-12           [8, 128, 64, 64]               0
        LeakyReLU-13           [8, 128, 64, 64]               0
           Conv2d-14           [8, 256,

In [2]:
# Print a per-layer summary (output shapes and parameter counts) of the
# mapping network on the GPU.
# NOTE(review): the first input shape is image-shaped, and the recorded
# output lists the Linear layers with output shape [8, 3, 128, 512], i.e. they
# are applied along the last axis of an image-sized tensor. Confirm whether a
# flat vector input such as (128,) was intended here.
summary(
    MappingNetwork(128).cuda(),
    [(3, 128, 128), (1,)],
    batch_size=8,
)

tensor([0, 1], device='cuda:0') tensor([0, 0], device='cuda:0')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [8, 3, 128, 512]          66,048
              ReLU-2           [8, 3, 128, 512]               0
            Linear-3           [8, 3, 128, 512]         262,656
              ReLU-4           [8, 3, 128, 512]               0
            Linear-5           [8, 3, 128, 512]         262,656
              ReLU-6           [8, 3, 128, 512]               0
            Linear-7           [8, 3, 128, 512]         262,656
              ReLU-8           [8, 3, 128, 512]               0
            Linear-9           [8, 3, 128, 512]         262,656
             ReLU-10           [8, 3, 128, 512]               0
           Linear-11           [8, 3, 128, 512]         262,656
             ReLU-12           [8, 3, 128, 512]               0
           Linear-13           [8, 3, 1

In [3]:
# Print a per-layer summary (output shapes and parameter counts) of the
# style encoder on the GPU, fed an image-shaped input and a 1-element input.
# NOTE(review): StyleEncoder(128) presumably takes the image size — confirm
# against stargan2.model.
summary(
    StyleEncoder(128).cuda(),
    [(3, 128, 128), (1,)],
    batch_size=8,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [8, 128, 128, 128]           3,584
            Conv2d-2         [8, 256, 128, 128]          32,768
         LeakyReLU-3         [8, 128, 128, 128]               0
         LeakyReLU-4         [8, 128, 128, 128]               0
         LeakyReLU-5         [8, 128, 128, 128]               0
         LeakyReLU-6         [8, 128, 128, 128]               0
         LeakyReLU-7         [8, 128, 128, 128]               0
            Conv2d-8         [8, 128, 128, 128]         147,584
         LeakyReLU-9           [8, 128, 64, 64]               0
        LeakyReLU-10           [8, 128, 64, 64]               0
        LeakyReLU-11           [8, 128, 64, 64]               0
        LeakyReLU-12           [8, 128, 64, 64]               0
        LeakyReLU-13           [8, 128, 64, 64]               0
           Conv2d-14           [8, 256,