In [7]:
%matplotlib inline
import matplotlib.pyplot as plt

from torchvision.utils import make_grid
from torch import nn
from pytorch_lightning import Trainer
from preprocessing.image_transform import ImageTransform
from preprocessing.seg_transforms import SegImageTransform
from datasets.monet import MonetDataModule
from datasets.agri import AgriDataModule
from systems.cycle_gan_system import CycleGANSystem
from models.generators import CycleGANGenerator
from models.discriminators import CycleGANDiscriminator
from utils.weight_initializer import init_weights
from datasets.gogoll import GogollDataModule
import os, glob, random

from models.unet_light_semseg import UnetLight
from datasets.generated import GeneratedDataModule
from datasets.mixed import MixedDataModule
from datasets.mixed import MixedDataset
from datasets.mixedCV import MixedDataModuleCV
from datasets.source import SourceDataModule

from systems.gogoll_system import GogollSystem

from sklearn.model_selection import KFold
from models.unet_light_semseg import UnetLight


# Root directory of every dataset below, and the Agri target domain used by the
# sanity-check and CycleGAN training cells.
data_dir, domain = './data', "domainB"

In [3]:
# Per-sub-network learning rates handed to GogollSystem. "seg_s"/"seg_t" match
# the segmentation-net keys of gogoll_net_config; "G"/"D" presumably cover the
# generators and discriminators.
lr = dict.fromkeys(("G", "D", "seg_s", "seg_t"), 0.0002)

In [4]:
# Sub-Models  -----------------------------------------------------------------
# Instances are created in the same order as before (seg nets, generators,
# discriminators) so that, if the constructors draw random initial weights,
# the RNG is consumed in the same sequence.
seg_net_s, seg_net_t = UnetLight(), UnetLight()
G_basestyle, G_stylebase = CycleGANGenerator(filter=32), CycleGANGenerator(filter=32)
D_base, D_style = CycleGANDiscriminator(filter=32), CycleGANDiscriminator(filter=32)

In [5]:
# Bundle the generators, discriminators, segmentation nets, learning rates and
# loss weights into GogollSystem's keyword arguments.
gogoll_net_config = dict(
    G_s2t=G_basestyle,
    G_t2s=G_stylebase,
    D_source=D_base,
    D_target=D_style,
    seg_s=seg_net_s,
    seg_t=seg_net_t,
    lr=lr,
    reconstr_w=10,
    id_w=2,
    seg_w=0.8,
)
main_system = GogollSystem(**gogoll_net_config)

In [6]:
transform = SegImageTransform(img_size=256)
batch_size = 8

# Real source images plus images produced by main_system.G_s2t (max_imgs=200
# presumably caps each), combined by the cross-validation mixed datamodule.
source_dm = SourceDataModule(data_dir, transform, batch_size=1, max_imgs=200)
generated_dm = GeneratedDataModule(main_system.G_s2t, data_dir, transform, batch_size=1, max_imgs=200)
mixed_dm = MixedDataModuleCV(source_dm, generated_dm, batch_size=batch_size)
mixed_dm.prepare_data()
mixed_dm.setup()

# Pull one training batch just to check tensor shapes.
dataloader = mixed_dm.train_dataloader()
batch = next(iter(dataloader))
base = batch['source']
style = batch['source_segmentation']  # segmentation targets, despite the name

print(f'Input Shape {base.size()}, {style.size()}')

  "Argument interpolation should be of type InterpolationMode instead of int. "


Input Shape torch.Size([8, 3, 256, 256]), torch.Size([8, 256, 256])


In [8]:
# Fresh, standalone segmentation net (not one of main_system's sub-models),
# used only to check the forward-pass output shape on the batch loaded above.
model = UnetLight()

In [9]:
# Inference-only forward pass for a shape check: eval mode plus no_grad so no
# autograd graph is built and kept alive through `r`. torch is imported here
# because the notebook's own `import torch` cell comes later.
import torch

model.eval()
with torch.no_grad():
    r = model(base)

In [10]:
# Output shape of the network; the channel axis presumably holds per-class logits.
r.size()

torch.Size([8, 3, 256, 256])

In [15]:
import torch

# Softmax over the channel axis (dim=1). Not applied by any later cell: argmax
# over the raw outputs already yields the same labels.
soft = torch.nn.Softmax(dim=1)

In [16]:
# Collapse the channel axis into a per-pixel label map and show its shape.
r2 = r.argmax(dim=1)
r2.size()

torch.Size([8, 256, 256])

In [None]:
# Fetch the wrapped datamodules from the CV mixed datamodule (presumably the
# source and generated datamodules passed to MixedDataModuleCV).
# NOTE(review): the method is spelled `get_datamoduels`; it must match the
# definition in datasets.mixedCV exactly — confirm before renaming either side.
dms = mixed_dm.get_datamoduels()

In [None]:
# 5-fold splitter without shuffling (KFold's defaults: shuffle=False, random_state=None).
cv_splitter = KFold(5)

In [None]:
# Collect the complete, unsplit dataset of every wrapped datamodule.
all_datasets = [datamodule.full_dataset for datamodule in dms]

In [None]:
from torch.utils.data import DataLoader, Dataset, ConcatDataset

In [None]:
# One index space over all collected datasets, so K-fold can split them jointly.
conc_datasets = ConcatDataset(datasets=all_datasets)

In [None]:
# K-fold split over the concatenated dataset. Each entry of `cv_folds` is a
# (train_index, test_index) pair of integer index arrays into conc_datasets,
# e.g. for use with torch.utils.data.Subset. A notebook has no `self`, so the
# module-level cv_splitter is used.
cv_folds = list(cv_splitter.split(conc_datasets))

In [None]:
# Wrap the datasets of each split in a MixedDataset.
# FIXME(review): `train_loaders`, `val_loaders` and `test_loaders` are not defined
# in any visible cell, so this cell raises NameError on a fresh kernel. Each
# element must expose a `.dataset` attribute (as torch DataLoader does); define
# these lists (e.g. from the K-fold split above) or remove this cell.
train_dataset = MixedDataset([x.dataset for x in train_loaders])
val_dataset = MixedDataset([x.dataset for x in val_loaders])
test_dataset = MixedDataset([x.dataset for x in test_loaders])

In [None]:
# Rebuild the datamodules with the plain (non-CV) MixedDataModule and peek at
# one training batch.
transform = SegImageTransform(img_size=256)
batch_size = 8

# Source domain datamodule
source_dm = SourceDataModule(data_dir, transform, batch_size=1, max_imgs=200)
# Generated images datamodule
generated_dm = GeneratedDataModule(main_system.G_s2t, data_dir, transform, batch_size=1, max_imgs=200)
# Mix both datamodules
mixed_dm = MixedDataModule(
    source_dm,
    generated_dm,
    batch_size=batch_size
)
# `dm` is not bound by any earlier cell, yet this cell and the following ones
# (set_active_split / test_dataloader) use it, so alias the mixed datamodule.
dm = mixed_dm

dm.prepare_data()  # call first initialization function before we start asking for data
dm.setup()  # call second initialization function before we start asking for data

dataloader = dm.train_dataloader()  # get the loader that returns us data
batch = next(iter(dataloader))  # ask for the next batch of data
base, style = (batch['source'], batch['source_segmentation'])

print('Input Shape {}, {}'.format(base.size(), style.size()))

In [None]:
# Presumably selects cross-validation fold 2 for the dataloaders requested below.
# NOTE(review): at this point `dm` is meant to be the mixed datamodule from the
# cell above, which uses the non-CV MixedDataModule — confirm it provides
# `set_active_split`, or build it with MixedDataModuleCV instead.
dm.set_active_split(2)

In [None]:
# Peek at the first batch of the test split and report the tensor shapes.
dataloader = dm.test_dataloader()
batch = next(iter(dataloader))
base = batch['source']
style = batch['source_segmentation']
print(f'Input Shape {base.size()}, {style.size()}')

In [None]:
# Display the repr of the most recently created dataloader.
dataloader

In [None]:
# Collect source RGB images, their semantic-segmentation masks and the
# target-domain images. glob returns files in filesystem-dependent order, while
# later cells pair rgb_paths[i] with segmentation_paths[i] by index. Both lists
# are therefore sorted, which aligns image and mask on the assumption that each
# mask shares its image's filename — TODO confirm naming scheme.
rgb_paths = sorted(glob.glob(os.path.join(data_dir, "exp", "rgb", "*.png")))
segmentation_paths = sorted(glob.glob(os.path.join(data_dir, "exp", "semseg", "*.png")))
target_paths = sorted(glob.glob(os.path.join(data_dir, "other_domains", "domainA", "*.jpg")))

# Index-based pairing is only meaningful if every image has exactly one mask.
assert len(rgb_paths) == len(segmentation_paths), (
    f"{len(rgb_paths)} RGB images vs {len(segmentation_paths)} masks"
)

In [None]:
import torch
import numpy as np
from sklearn.model_selection import KFold

In [None]:
# Re-create the 5-fold, unshuffled splitter (same settings as the earlier cell).
cv_splitter = KFold(n_splits=5)

In [None]:
# NumPy copies of the path lists so they can be fancy-indexed with KFold's
# integer index arrays.
rgb_paths_np, seg_paths_np = np.asarray(rgb_paths), np.asarray(segmentation_paths)

In [None]:
# Six independent, empty per-fold accumulators, filled by the K-fold cells below.
(rgb_train_splits, seg_train_splits,
 rgb_test_splits, seg_test_splits,
 train_datasets, test_datasets) = ([] for _ in range(6))

In [None]:
# Split the paired image/mask path lists into K folds. The four per-fold lists
# are (re)initialised here so re-running this cell does not append a second
# copy of every fold.
rgb_train_splits, rgb_test_splits = [], []
seg_train_splits, seg_test_splits = [], []
for train_index, test_index in cv_splitter.split(rgb_paths_np):
    # The same index arrays select images and masks, keeping each pair aligned.
    rgb_train_splits.append(rgb_paths_np[train_index].tolist())
    rgb_test_splits.append(rgb_paths_np[test_index].tolist())
    seg_train_splits.append(seg_paths_np[train_index].tolist())
    seg_test_splits.append(seg_paths_np[test_index].tolist())

In [None]:
# Build one train and one held-out SourceDataset per fold. Both lists are
# rebuilt from scratch, so re-running is idempotent. The fold count follows
# the split lists instead of a hard-coded 5.
# SourceDataset is not imported at the top of the notebook; it is assumed to
# live next to SourceDataModule in datasets.source — TODO confirm.
from datasets.source import SourceDataset

train_datasets = [
    SourceDataset(rgb_split, seg_split, transform, "train")
    for rgb_split, seg_split in zip(rgb_train_splits, seg_train_splits)
]
# NOTE(review): the held-out folds are also built with phase "train". If
# SourceDataset applies training-only augmentation for that phase, these
# should use the evaluation phase — confirm which values SourceDataset accepts.
test_datasets = [
    SourceDataset(rgb_split, seg_split, transform, "train")
    for rgb_split, seg_split in zip(rgb_test_splits, seg_test_splits)
]

In [None]:
from torch.utils.data import DataLoader, Dataset

In [None]:
# Display the per-fold training datasets built above.
train_datasets

In [None]:
# Peek at one batch from the held-out dataset of fold index 4 (the last of the
# five K-fold splits). The loader is named `fold_loader` rather than `dm`,
# because elsewhere in the notebook `dm` names a datamodule, not a DataLoader.
fold_loader = DataLoader(
    test_datasets[4],
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)
batch = next(iter(fold_loader))

In [None]:
# Simple 80/20 hold-out split of the paired image/mask paths. train_test_split
# applies one permutation to both lists, so pairs stay aligned. A fixed
# random_state makes the split reproducible. The import is done here because
# only KFold is imported at the top of the notebook.
from sklearn.model_selection import train_test_split

rgb_train, rgb_val, seg_train, seg_val = train_test_split(
    rgb_paths, segmentation_paths, test_size=0.2, random_state=42
)

In [None]:
# Sanity Check
# Load one test batch from the Agri datamodule and report its tensor shapes.
transform = ImageTransform(img_size=256)
batch_size = 8

dm = AgriDataModule(data_dir, transform, batch_size, domain=domain)
# Both initialisation steps must run before any dataloader is requested.
dm.prepare_data()
dm.setup()

dataloader = dm.test_dataloader()
batch = next(iter(dataloader))
base = batch['source']
style = batch['target']

print(f'Input Shape {base.size()}, {style.size()}')

In [None]:
def show_image_grid(images, title, nrow=4, padding=2):
    """Plot a batch of images as one grid figure.

    Args:
        images: image batch tensor accepted by torchvision's ``make_grid``
            (may live on any device; it is moved to CPU for plotting).
        title: figure title.
        nrow: number of images per grid row.
        padding: pixels of padding between images.

    Pixel values are mapped with ``x * 0.5 + 0.5`` (undoing a presumed
    mean=0.5 / std=0.5 normalisation). They are then clamped to [0, 1] so
    matplotlib receives a valid float RGB image, instead of clipping
    out-of-range values with a warning.
    """
    grid = make_grid(images.detach().cpu(), nrow=nrow, padding=padding)
    grid = (grid * 0.5 + 0.5).clamp(0.0, 1.0)
    fig, ax = plt.subplots(figsize=(18, 8), facecolor='w')
    ax.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.axis('off')
    ax.set_title(title)
    plt.show()


show_image_grid(base, 'Source Domain')
show_image_grid(style, 'Target Domain')

In [None]:
# Config  -----------------------------------------------------------------
from pytorch_lightning import seed_everything

transform = ImageTransform(img_size=256)
batch_size = 8
# Own name so the Gogoll `lr` dict from the top of the notebook is not clobbered.
cyclegan_lr = {
    'G': 0.0002,
    'D': 0.0002
}
epoch = 180
seed = 42
reconstr_w = 10
id_w = 2

# Seed Python, NumPy and torch RNGs before any network is built, so weight
# initialisation and data shuffling are reproducible.
seed_everything(seed)

# DataModule  -----------------------------------------------------------------
dm = AgriDataModule(data_dir, transform, batch_size, domain=domain)
# TODO(review): viz_set (batch size 4) is not used by this cell; wire it into a
# visualisation step or remove it.
viz_set = AgriDataModule(data_dir, transform, 4, domain=domain)

G_basestyle = CycleGANGenerator(filter=32)
G_stylebase = CycleGANGenerator(filter=32)
D_base = CycleGANDiscriminator(filter=32)
D_style = CycleGANDiscriminator(filter=32)

# Init Weight  --------------------------------------------------------------
for net in [G_basestyle, G_stylebase, D_base, D_style]:
    init_weights(net, init_type='normal')

# LightningModule  --------------------------------------------------------------
model = CycleGANSystem(G_basestyle, G_stylebase, D_base, D_style, cyclegan_lr, transform, reconstr_w, id_w)

# Trainer  --------------------------------------------------------------
# NOTE(review): `gpus`, `checkpoint_callback` and `reload_dataloaders_every_epoch`
# belong to older pytorch_lightning 1.x releases and were removed later —
# confirm the installed version.
trainer = Trainer(
    logger=False,
    max_epochs=epoch,
    gpus=1,
    checkpoint_callback=False,
    reload_dataloaders_every_epoch=True,
    num_sanity_val_steps=0,  # Skip Sanity Check
)

# Train
trainer.fit(model, datamodule=dm)