In [1]:
import logging
import pathlib
import random
import shutil
import time
import argparse

In [2]:
from fastmri_common.args import Args
from fastmri_common.subsample import MaskFunc
from fastmri_data import transforms
from fastmri_data.mri_data import SliceData

In [3]:
# from pix2pix_data import create_datset
from pix2pix_models import create_model, get_option_setter
from pix2pix_util.visualizer import Visualizer
from pix2pix_options.train_options import TrainOptions
from pix2pix_options.test_options import TestOptions

In [4]:
import torch
from torch import nn as nn
import torchvision
import numpy as np
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

In [5]:
from torchvision.transforms import RandomCrop, ToTensor, ToPILImage, Compose

In [6]:
from torchvision.utils import make_grid

In [7]:
from tqdm import tqdm_notebook as tqdm

In [8]:
class DataTransform:
    """
    Data Transformer for training U-Net models.
    """

    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        """
        Args:
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            resolution (int): Resolution of the image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
        self.mask_func = mask_func
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(self, kspace, target, attrs, fname, slice_):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std, attrs['norm'].astype(np.float32)


def create_datasets(args):
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data


def create_data_loaders(args):
    dev_data, train_data = create_datasets(args)
    display_data = [dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)]

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        batch_size=args.batch_size,
        num_workers=8,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader

In [9]:
class FastMriToPix2PixTransform(DataTransform):
    
    def get_crop_positions(self, image, crop_size=256):
        r, c = image.shape[-2:]
        for i, dim in enumerate(image.shape[-2:]):
            if dim < crop_size:
                raise ValueError("Dimension: {} less than size: {}".format(i, crop_size))
        rp = np.random.randint(r-256)
        cp = np.random.randint(c-256)
        return rp, cp
    
    def transform_image(self, image, crop_params, crop_size=256):
        r, c = crop_params
        return image[...,r:r+crop_size, c:c+crop_size].unsqueeze(0)
    
    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        super().__init__(mask_func, resolution, which_challenge, use_seed)
    def __call__(self, kspace, target, attrs, fname, slice_):    
        image, target, mean, std, norm = super().__call__(
            kspace, target, attrs, fname, slice_
        )
        
        crop_params = self.get_crop_positions(image)
        
        return {
            "A": self.transform_image(image, crop_params),
            "B": self.transform_image(target, crop_params),
            "A_paths": "DEBUG! Do not use",
            "B_paths": "DEBUG! Do not use"
        }

In [10]:
fastMriArgs = Args()

In [11]:
args = fastMriArgs.parse_args(
    [
        "--challenge", "singlecoil",
        "--data-path", "./data/"
    ]
)

In [12]:
train_mask = MaskFunc(
    args.center_fractions,
    args.accelerations
)
dev_mask = MaskFunc(
    args.center_fractions,
    args.accelerations
)

In [13]:
train_data = SliceData(
    root=args.data_path / f'{args.challenge}_train',
    transform=FastMriToPix2PixTransform(
        train_mask, args.resolution, args.challenge, use_seed=True
    ),
    sample_rate=args.sample_rate,
    challenge=args.challenge,    
)

In [14]:
val_data = SliceData(
    root=args.data_path / f'{args.challenge}_val',
    transform=FastMriToPix2PixTransform(
        dev_mask, args.resolution, args.challenge, use_seed=True
    ),
    sample_rate=args.sample_rate,
    challenge=args.challenge,
)

In [15]:
train_options = TrainOptions()
test_options = TestOptions()

```
# !python train.py --dataroot ./datasets/facades --name facades_pix2pix_mason --model pix2pix --direction BtoA

!python train.py --dataroot ./datasets/facades --name facades_label2photo --model pix2pix --batch_size 16
```

In [27]:
train_args = [
    "--dataroot", "./",
    "--name", "pixel_low_lr",
    "--model", "pix2pix",
    "--input_nc", "1",
    "--output_nc", "1",
    "--no_flip",
    "--display_id", "-1",
    "--isTrain", "True",
    "--gpu_ids", "0",
    "--batch_size", "16",
    "--netG", "unet_256",
    "--netD", "pixel",
    "--lr_policy", "cosine",
    "--lr_g", '0.00001',
    "--lr_d", '0.000005',
    "--continue_train",
    "--epoch", "1",
#     "--gan_mode", "wgangp",
    "--verbose",
]

In [28]:
test_args = [
    "--dataroot", "./",
    "--name", "val",
    "--model", "pix2pix",
    "--input_nc", "1",
    "--output_nc", "1",
    "--no_flip",
    "--isTrain", "False",
    "--gpu_ids", "0",
    "--batch_size", "64",
    "--verbose",
]

In [29]:
train_opts = argparse.ArgumentParser()
test_opts = argparse.ArgumentParser()

In [30]:
train_opts = train_options.initialize(train_opts)
test_opts = test_options.initialize(test_opts)

In [31]:
train_opts.add_argument("--isTrain", default=True)
test_opts.add_argument('--isTrain', default=False)

_StoreAction(option_strings=['--isTrain'], dest='isTrain', nargs=None, const=None, default=False, type=None, choices=None, help=None, metavar=None)

In [32]:
train_opts = get_option_setter("pix2pix")(train_opts)
test_opts = get_option_setter("pix2pix")(test_opts)

In [33]:
train_opts = train_opts.parse_args(train_args)
test_opts = test_opts.parse_args(test_args)

In [23]:
train_loader = DataLoader(
    dataset=train_data,
    batch_size=train_opts.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)
test_loader = DataLoader(
    dataset=val_data,
    batch_size=test_opts.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

In [24]:
model = create_model(train_opts)

initialize network with normal
initialize network with normal
model [Pix2PixModel] was created


In [34]:
model.setup(train_opts)

Hurrah!
loading the model from ./checkpoints/pixel_low_lr/1_net_G.pth
loading the model from ./checkpoints/pixel_low_lr/1_net_D.pth
---------- Networks initialized -------------
UnetGenerator(
  (model): UnetSkipConnectionBlock(
    (model): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetSkipConnectionBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): 

In [35]:
num_epochs = 10

In [36]:
comment = "Discriminator_LRG_1e-5_LRD_5e-6_D-SGD_pixel_low_lr"

In [37]:
writer = SummaryWriter(comment=comment)

In [38]:
def write_loss_scalars(writer, losses_dict, step):
    for key in losses_dict.keys():
        writer.add_scalar(key, losses_dict[key], step)

In [39]:
dataset_size = len(train_loader)

In [40]:
test_size = len(test_loader)

In [None]:
for epoch_ in tqdm(range(num_epochs)):
    epoch = epoch_+1
    
    for i, batch in enumerate(tqdm(train_loader)):
        model.set_input(batch)
        model.optimize_parameters()
        losses = model.get_current_losses()
        write_loss_scalars(writer, losses, epoch*dataset_size+i)
        
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader)):
            outputs = model.netG(batch['A'].to(torch.device("cuda:0")))
            writer.add_image(f"generated_"+comment, make_grid(outputs.cpu()), epoch)
            writer.add_image(f"subsampled_"+comment, make_grid(batch['A']), epoch)
            writer.add_image(f"target_"+comment, make_grid(batch['B']), epoch)
            break

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2172), HTML(value='')))

save_filename = "1_net_G.pth"

model.model_names

save_path = model.save_dir+"/"+save_filename

torch.save(model.netG.cpu().state_dict(), save_path)

model1 = create_model(train_opts)

model1.load_networks(11)