In [1]:
import logging
import pathlib
import random
import shutil
import time
import argparse

In [2]:
from fastmri_common.args import Args
from fastmri_common.subsample import MaskFunc
from fastmri_data import transforms
from fastmri_data.mri_data import SliceData

In [3]:
# from pix2pix_data import create_datset
from pix2pix_models import create_model
from pix2pix_util.visualizer import Visualizer
from pix2pix_options.train_options import TrainOptions

In [4]:
import torch
from torch import nn as nn
import torchvision
import numpy as np
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

In [5]:
from tqdm import tqdm_notebook as tqdm

In [6]:
class DataTransform:
    """
    Data Transformer for training U-Net models.
    """

    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        """
        Args:
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            resolution (int): Resolution of the image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        if which_challenge not in ('singlecoil', 'multicoil'):
            raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"')
        self.mask_func = mask_func
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.use_seed = use_seed

    def __call__(self, kspace, target, attrs, fname, slice_):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image.
                target (torch.Tensor): Target image converted to a torch Tensor.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (float): L2 norm of the entire volume.
        """
        kspace = transforms.to_tensor(kspace)
        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image = transforms.complex_abs(image)
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)
        # Normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        target = transforms.to_tensor(target)
        # Normalize target
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)
        return image, target, mean, std, attrs['norm'].astype(np.float32)


def create_datasets(args):
    train_mask = MaskFunc(args.center_fractions, args.accelerations)
    dev_mask = MaskFunc(args.center_fractions, args.accelerations)

    train_data = SliceData(
        root=args.data_path / f'{args.challenge}_train',
        transform=DataTransform(train_mask, args.resolution, args.challenge),
        sample_rate=args.sample_rate,
        challenge=args.challenge
    )
    dev_data = SliceData(
        root=args.data_path / f'{args.challenge}_val',
        transform=DataTransform(dev_mask, args.resolution, args.challenge, use_seed=True),
        sample_rate=args.sample_rate,
        challenge=args.challenge,
    )
    return dev_data, train_data


def create_data_loaders(args):
    dev_data, train_data = create_datasets(args)
    display_data = [dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)]

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )
    dev_loader = DataLoader(
        dataset=dev_data,
        batch_size=args.batch_size,
        num_workers=8,
        pin_memory=True,
    )
    display_loader = DataLoader(
        dataset=display_data,
        batch_size=16,
        num_workers=8,
        pin_memory=True,
    )
    return train_loader, dev_loader, display_loader

In [7]:
class FastMriToPix2PixTransform(DataTransform):
    def __init__(self, mask_func, resolution, which_challenge, use_seed=True):
        super().__init__(mask_func, resolution, which_challenge, use_seed)
    def __call__(self, kspace, target, attrs, fname, slice_):    
        image, target, mean, std, norm = super().__call__(
            kspace, target, attrs, fname, slice_
        )
        pad = nn.ConstantPad2d(padding=1, value=0)
        return {
            "A": pad(image),
            "B": pad(image),
            "A_paths": "DEBUG! Do not use",
            "B_paths": "DEBUG! Do not use"
        }

In [8]:
fastMriArgs = Args()

In [9]:
args = fastMriArgs.parse_args(
    [
        "--challenge", "singlecoil",
        "--data-path", "./data/"
    ]
)

In [10]:
train_mask = MaskFunc(
    args.center_fractions,
    args.accelerations
)

In [11]:
train_data = SliceData(
    root=args.data_path / f'{args.challenge}_val',
    transform=FastMriToPix2PixTransform(
        train_mask, args.resolution, args.challenge, use_seed=True
    ),
    sample_rate=args.sample_rate,
    challenge=args.challenge,    
)

In [12]:
train_options = TrainOptions()

```
# !python train.py --dataroot ./datasets/facades --name facades_pix2pix_mason --model pix2pix --direction BtoA

!python train.py --dataroot ./datasets/facades --name facades_label2photo --model pix2pix --batch_size 16
```

In [13]:
train_args = [
    "--dataroot", "./",
    "--name", "train",
    "--model", "pix2pix",
    "--input_nc", "1",
    "--output_nc", "1",
    "--no_flip",
    "--display_id", "-1",
    "--isTrain", "True",
    "--gpu_ids", "0",
    "--batch_size", "32",
    "--verbose",
]

In [14]:
test_args = [
    "--dataroot", "./",
    "--name", "train",
    "--model", "pix2pix",
    "--input_nc", "1",
    "--output_nc", "1",
    "--no_flip",
    "--display_id", "-1",
    "--isTrain", "False",
    "--gpu_ids", "0",
    "--batch_size", "64",
    "--verbose",
]

In [15]:
options = argparse.ArgumentParser()

In [16]:
train_options = TrainOptions()

In [17]:
options = train_options.initialize(options)

In [18]:
options.add_argument("--isTrain", default=True)

_StoreAction(option_strings=['--isTrain'], dest='isTrain', nargs=None, const=None, default=True, type=None, choices=None, help=None, metavar=None)

In [19]:
options = options.parse_args(train_args)

In [20]:
from pix2pix_models.pix2pix_model import Pix2PixModel

In [21]:
model = Pix2PixModel(options)

initialize network with normal
initialize network with normal


In [22]:
train_loader = DataLoader(
    dataset=train_data,
    batch_size=options.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

In [23]:
for batch in train_loader:
    batch1 = batch
    break

In [26]:
fakeB = model.netG(batch1['A'])

AssertionError: 3D tensors expect 2 values for padding