In [89]:
import os
import numpy as np
import datetime
import argparse
import sys
## Imports for plotting
import matplotlib.pyplot as plt
import matplotlib

## PyTorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

## Torchvision
from torchvision import datasets, transforms

import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger

from scipy.optimize import linear_sum_assignment
from scipy.sparse import coo_matrix

import random
import gzip
import shutil
from scipy.ndimage import shift

# Select PyTorch's 'medium' float32 matmul precision setting for the whole process.
torch.set_float32_matmul_precision('medium')

# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [90]:
# Hyperparameters / run configuration.
# When run as a script, every key below is also exposed as a `--<key>` command-line flag.
#
# NOTE(review): later cells of this notebook read keys that are not defined here
# ('num_pieces', 'image_size', 'kernel_size', 'tau', 'n_sink_iter', 'dataset_percent',
# 'model'), and the dataset-loading cell only has branches for 'MNIST' and
# 'MNIST_SHIFT' -- confirm the intended values before running top to bottom.
h = {
    # --- data / model ---
    "dataset": "RADIO",  # alternatives: 'MNIST_SHIFT', 'MNIST'
    "in_channels": 2,
    "hidden_channels": 32,

    # --- optimisation ---
    "epochs": 2,
    "learning_rate": 1e-4,

    # --- paths ---
    "checkpoint_path": "./saved_models",
    "dataset_path": "../data/",

    # --- data loading ---
    "batch_size": 128,
    "num_workers": 8,
    "persistent_workers": True,
    "pin_memory": True,

    # --- run control ---
    "train": False,  # Set this to false if you only want to evaluate the model
    "use_wandb": False,
}

In [91]:
def _str2bool(text):
    """Convert a command-line token such as 'true' / 'false' into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including the string "False" -- is parsed as ``True``.
    This converter interprets the text instead.

    Args:
        text: The raw command-line string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised boolean spelling.
    """
    if isinstance(text, bool):
        return text
    normalized = text.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as False: it was the only falsy input under bool().
    if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


# Check if the script is being run in a Jupyter notebook
if 'ipykernel' not in sys.modules:
    # Parse command-line arguments: one `--<key>` flag per hyperparameter,
    # typed after (and defaulting to) the value currently stored in `h`.
    parser = argparse.ArgumentParser()
    for key, value in h.items():
        if isinstance(value, bool):
            # bool must be tested before int because bool is a subclass of int.
            # nargs='?' with const=True also accepts a bare `--flag` as True.
            parser.add_argument(f'--{key}', type=_str2bool, nargs='?', const=True, default=value)
        elif isinstance(value, int):
            parser.add_argument(f'--{key}', type=int, default=value)
        elif isinstance(value, float):
            parser.add_argument(f'--{key}', type=float, default=value)
        else:  # for str and potentially other types
            parser.add_argument(f'--{key}', type=type(value), default=value)
    args = parser.parse_args()

    # Overwrite the default hyperparameters with the command-line arguments
    h.update(vars(args))

# In terminal, run (only keys present in `h` are accepted as flags):
# python main.py --epochs 30 --train true

In [98]:
# if h['model'] == 'Sink':
#     SinkhornConvNet = SinkhornConvNetV1
# elif h['model'] == 'SinkV2':
#     SinkhornConvNet = SinkhornConvNetV2
# elif h['model'] == 'SinkV3':
#     SinkhornConvNet = SinkhornConvNetV3
# else:
#     raise ValueError(f"Unknown model: {h['model']}")

In [99]:
# Example usage
# model = SinkhornConvNet(
#     in_channels=1, num_pieces=2, image_size=28, hidden_channels=32, kernel_size=5, tau=0.1, n_sink_iter=20)
# 
# random_pieces = torch.rand((64, 4, 1, 14, 14))
# res1, res2 = model(random_pieces)
# res1.shape, res2.shape

torch.Size([64, 1568])
torch.Size([64, 6272])


(torch.Size([64, 4, 1, 14, 14]), torch.Size([64, 4, 4]))

In [101]:
# Select the GPU only when CUDA is actually available; otherwise fall back to the CPU
# instead of unconditionally requesting 'cuda'.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Convert images to tensors and normalise them with a fixed mean / std
# (presumably the MNIST statistics -- TODO confirm they suit other datasets).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Read the dataset root from the hyperparameters so that `dataset_path`
# (and its command-line override) is honoured.
data_root = h['dataset_path']

if h['dataset'] == 'MNIST':
    trainset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
    testset = datasets.MNIST(root=data_root, train=False, transform=transform)
elif h['dataset'] == 'MNIST_SHIFT':
    # MNIST_SHIFT is expected to be defined elsewhere in the notebook.
    trainset = MNIST_SHIFT(root=data_root, train=True, download=True, transform=transform)
    testset = MNIST_SHIFT(root=data_root, train=False, transform=transform)
else:
    raise ValueError(f"Unknown dataset: {h['dataset']}")

# Options shared by both loaders.
loader_kwargs = dict(
    num_workers=h['num_workers'],
    # DataLoader raises a ValueError for persistent_workers=True when num_workers == 0.
    persistent_workers=h['persistent_workers'] and h['num_workers'] > 0,
    # Pinned memory only helps host-to-GPU copies; skip it when there is no CUDA device.
    pin_memory=h['pin_memory'] and torch.cuda.is_available(),
)

# Training batches are shuffled and the last incomplete batch is dropped;
# the test loader keeps every sample in its original order.
train_loader = DataLoader(trainset, h['batch_size'], drop_last=True, shuffle=True, **loader_kwargs)
test_loader = DataLoader(testset, h['batch_size'], drop_last=False, shuffle=False, **loader_kwargs)

In [103]:
class CustomCallbacks(pl.Callback):
    """Lightning callback that logs input, predicted and ground-truth images to W&B.

    At the end of selected training epochs (and once when training ends) one batch
    is drawn from a data loader, split into pieces with ``batch_chunk_image``, run
    through the model, and up to 10 reassembled images are logged through the
    given ``WandbLogger``.

    Args:
        plot_every_n_epoch: Log images on every epoch whose index is a multiple of this value.
        num_pieces: Forwarded unchanged to ``batch_chunk_image``.
        wandb_logger: Logger whose ``experiment`` receives the images.
    """

    def __init__(self, plot_every_n_epoch, num_pieces, wandb_logger: WandbLogger):
        super().__init__()
        self.plot_every_n_epoch = plot_every_n_epoch
        self.num_pieces = num_pieces
        self.wandb_logger = wandb_logger

    def assemble_image(self, pieces):
        """Tile a stack of equally sized pieces into one image.

        Args:
            pieces: Tensor of shape [num_pieces, channels, piece_height, piece_width],
                with the pieces stored in row-major grid order.

        Returns:
            Tensor of shape [channels, side * piece_height, side * piece_width],
            where ``side * side == num_pieces``.

        Raises:
            ValueError: If ``num_pieces`` is not a perfect square.
        """
        num_pieces, channels, piece_height, piece_width = pieces.shape

        # round() guards against a float square root landing just below the integer.
        num_pieces_side = int(round(num_pieces ** 0.5))
        if num_pieces_side * num_pieces_side != num_pieces:
            raise ValueError(
                f"Cannot assemble {num_pieces} pieces into a square grid")

        # [side, side, channels, piece_height, piece_width];
        # reshape (unlike view) also accepts non-contiguous input.
        grid = pieces.reshape(num_pieces_side, num_pieces_side, channels, piece_height, piece_width)

        # [channels, side, piece_height, side, piece_width]
        grid = grid.permute(2, 0, 3, 1, 4)

        # [channels, height, width]
        return grid.reshape(channels, num_pieces_side * piece_height, num_pieces_side * piece_width)

    def log_images(self, trainer, pl_module, loader, prefix):
        """Log up to 10 input / predicted / ground-truth images from one batch of ``loader``.

        Nothing is logged unless the current epoch is a multiple of
        ``plot_every_n_epoch`` or is the trainer's last epoch.

        Args:
            trainer: The running trainer (provides epoch, max_epochs and global_step).
            pl_module: The model; called on the shuffled pieces.
            loader: Data loader yielding ``(image_batch, label_batch)``.
            prefix: String prepended to the logged image keys.
        """
        # Use the trainer's own epoch budget instead of the global hyperparameter dict.
        max_epochs = trainer.max_epochs
        on_schedule = trainer.current_epoch % self.plot_every_n_epoch == 0
        is_last_epoch = max_epochs is not None and trainer.current_epoch == max_epochs - 1
        if not (on_schedule or is_last_epoch):
            return

        # Remember the current mode so it can be restored afterwards.
        was_training = pl_module.training
        pl_module.eval()
        try:
            image_batch, label_batch = next(iter(loader))
            pieces, random_pieces, _ = batch_chunk_image(image_batch, self.num_pieces)
            pieces, random_pieces = pieces.to(pl_module.device), random_pieces.to(pl_module.device)

            # The outputs are only visualised, so no autograd graph is needed.
            with torch.no_grad():
                ordered_pieces, _ = pl_module(random_pieces)

            # Assemble the pieces into a single image before logging
            nb_ims = min(ordered_pieces.shape[0], 10)
            for i in range(nb_ims):
                initial_image = self.assemble_image(random_pieces[i])
                ordered_image = self.assemble_image(ordered_pieces[i])
                ground_truth_image = self.assemble_image(pieces[i])
                caption = f"Label {label_batch[i]}"

                # log to wandb
                self.wandb_logger.experiment.log(
                    {f"{prefix}_predicted_image/img_{i}": wandb.Image(ordered_image.cpu().squeeze(),
                                                                      caption=caption),
                     f"{prefix}_ground_truth/img_{i}": wandb.Image(ground_truth_image.cpu().squeeze(),
                                                                   caption=caption),
                     f"{prefix}_input/img_{i}": wandb.Image(initial_image.cpu().squeeze(),
                                                            caption=caption)},
                    step=trainer.global_step)
        finally:
            # Restore whichever mode the module was in, even if logging failed.
            pl_module.train(was_training)

    # The hooks are annotated with pl.LightningModule so that defining this class
    # does not require the concrete model class to exist yet.
    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Log images from the module-level ``train_loader`` after each training epoch."""
        self.log_images(trainer, pl_module, train_loader, "TRAIN")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Intentionally does nothing after validation epochs."""
        pass

    def on_train_end(self, trainer, pl_module):
        """Log images from the module-level ``test_loader`` once training has finished."""
        self.log_images(trainer, pl_module, test_loader, "TEST_SET_train_end")


In [104]:
# Initialize the model
# NOTE(review): 'num_pieces', 'image_size', 'kernel_size', 'tau' and 'n_sink_iter'
# are not present in the hyperparameter dict defined in this notebook's config
# cell -- they must be added there (or passed on the command line) before this runs.
model = SinkhornConvNet(in_channels=h['in_channels'],
                        num_pieces=h['num_pieces'],
                        image_size=h['image_size'],
                        hidden_channels=h['hidden_channels'],
                        kernel_size=h['kernel_size'],
                        tau=h['tau'],
                        n_sink_iter=h['n_sink_iter'])

if h['train']:
    # Without wandb, no custom logger or callbacks are passed to the Trainer.
    wandb_logger = None
    callbacks = None

    if h['use_wandb']:
        date_identifier = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        wandb_logger = WandbLogger()
        # Fall back to the model's class name when no 'model' hyperparameter is set.
        model_name = h.get("model", type(model).__name__)
        wandb.init(project='sinkhorn_netw', config=h, entity='oboii',
                   name=f'{model_name}_{h["dataset"]}_{h["num_pieces"]}_pieces_{date_identifier}')
        callbacks = [CustomCallbacks(plot_every_n_epoch=1, num_pieces=h['num_pieces'],
                                     wandb_logger=wandb_logger)]

    # Fraction (or number) of batches to use per loop; 1.0 matches Lightning's default.
    dataset_percent = h.get('dataset_percent', 1.0)

    try:
        trainer = pl.Trainer(
            max_epochs=h['epochs'],
            callbacks=callbacks,
            logger=wandb_logger,
            limit_train_batches=dataset_percent,
            limit_val_batches=dataset_percent,
            limit_test_batches=dataset_percent)

        # Train the model (the test loader doubles as the validation loader)
        trainer.fit(model, train_loader, val_dataloaders=test_loader)

        # Save the model; create the checkpoint directory first so that a missing
        # folder cannot discard the freshly trained weights.
        os.makedirs(h['checkpoint_path'], exist_ok=True)
        torch.save(model.state_dict(), os.path.join(h['checkpoint_path'], 'model.pth'))
    finally:
        # Finish the wandb run if one is active, even when training failed.
        if wandb.run is not None:
            wandb.finish()
