In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.utils import save_image
from torch.autograd import Variable
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import nibabel as nib
import glob
import random
import argparse
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim

from model import Generator
from dataset import BratDataset
import cv2

In [2]:
# Preprocessing pipeline applied to every slice: convert to a tensor, resize
# to 256, then normalise each of the three channels with mean 0.5 / std 0.5
# (which maps inputs in [0, 1] to [-1, 1]).
transform = T.Compose([
    T.ToTensor(),
    T.Resize(256),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Test split, using 'flair' as the source format.
dataset = BratDataset(
    '/home/bap/Downloads/BraTS2020_image_2D/image_2D/test',
    source_format='flair',
    transform=transform,
)
# NOTE(review): shuffle=True means the batches drawn from this loader differ
# between runs -- confirm that is intended for a test-time loader.
data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=1)

In [14]:
def label2onehot(labels, dim):
    """Convert label indices to one-hot vectors.

    Args:
        labels: 1-D tensor of class indices, one per sample. Any numeric
            dtype is accepted; values are converted with ``.long()``.
        dim: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Tensor of shape ``(labels.size(0), dim)`` in the default float dtype,
        allocated on the same device as ``labels``, with a single 1 per row.
    """
    batch_size = labels.size(0)
    # Allocate the result and the row index on the labels' device so the
    # indexed assignment below never mixes CPU and GPU tensors.
    out = torch.zeros(batch_size, dim, device=labels.device)
    rows = torch.arange(batch_size, device=labels.device)
    out[rows, labels.long()] = 1
    return out

def create_labels(c_org, c_dim=4, device=None):
    """Generate target domain labels for debugging and testing.

    Produces one one-hot label batch per domain, so the same input batch can
    be translated into every domain in turn.

    Args:
        c_org: Tensor of original domain labels. Only its batch size
            (``c_org.size(0)``) is read.
        c_dim: Number of domains.
        device: Device the returned tensors are moved to. ``None`` (default)
            selects ``'cuda'`` when CUDA is available -- the previous
            hard-coded behaviour -- and falls back to ``'cpu'`` otherwise
            instead of raising.

    Returns:
        List of ``c_dim`` tensors. The i-th entry is the ``label2onehot``
        encoding of a batch whose every label is ``i``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = c_org.size(0)
    c_trg_list = []
    for i in range(c_dim):
        # Every sample in the batch gets the same target domain index i.
        domain_idx = torch.full((batch_size,), i, dtype=torch.long)
        c_trg = label2onehot(domain_idx, c_dim)
        c_trg_list.append(c_trg.to(device))
    return c_trg_list

def denorm(x):
    """Map values from the [-1, 1] range to [0, 1], clamping anything outside.

    The input tensor is not modified; a new tensor is returned.
    """
    shifted = x + 1
    return (shifted / 2).clamp(min=0, max=1)


In [54]:
# Build the generator and restore the trained weights from the checkpoint.
# NOTE(review): the positional arguments are presumably
# (conv_dim, c_dim, repeat_num) -- confirm against model.Generator.
generator = Generator(64, 4, 6)
# map_location='cpu' keeps every tensor on the CPU while loading; the model
# is moved to the GPU later, where it is used.
generator.load_state_dict(
    torch.load(
        '/home/bap/Code/StarGANs---Generate-MRI-2D-images-master/stargan_brats2020/models/180000-G.ckpt',
        map_location='cpu',
    )
)

<All keys matched successfully>

In [17]:
# from data_loader import get_loader
# old_trainloader = get_loader('/home/bap/Downloads/BraTS2020_image_2D/image_2D/test', image_size=256, batch_size=1, mode='test', num_workers=1)

In [None]:
# %matplotlib inline
# generator.cuda()
# item = next(iter(old_trainloader))
# with torch.no_grad():
#     fake = generator(item[0].cuda(), create_labels(torch.ones(16)*3, c_dim=4)[1].cuda())
#     # print(fake)
#     fake = (fake + 1) / 2
#     fake = fake[0].clamp_(0, 1)
#     save_image(fake, 'fake.png')
#     # fake = fake.permute(1, 2, 0)
#     # fake = fake.cpu().numpy()
#     # cv2.imwrite('fake.png', fake*255)
#     # plt.imshow(fake)

In [53]:
# %matplotlib inline
# generator.to('cuda')
# # item = next(iter(old_trainloader))
# with torch.no_grad():
#     for i, (x_real, c_org) in enumerate(old_trainloader):
#         c_trg_list = create_labels(c_org.cuda(), c_dim=4)
#         x_fake_list = [x_real.cuda()]
#         for c_trg in c_trg_list:
#             # ind = 1
#             x_fake_list.append(generator(x_real.cuda(), c_trg))
#             # fake = generator(item['source'][0].cuda(), create_labels(item['source'][1].cuda(), c_dim=4)[ind+1].cuda()).data.cpu()
#             # fake = (fake + 1) / 2
#             # fake = fake.clamp_(0, 1)
#             # save_image(fake, 'fake.png')
#             # source = item['source'][0]
#             # source = (source + 1) / 2
#             # source = source.clamp_(0, 1)
#             # save_image(source, 'source.png')
#             # target = item['target'][ind][0]
#             # target = (target + 1) / 2
#             # target = target.clamp_(0, 1)
#             # save_image(target, 'target.png')

#         x_concat = torch.cat(x_fake_list, dim=3).data.cpu()
#         x_concat = (x_concat + 1) / 2
#         x_concat = x_concat.clamp_(0, 1)
#         save_image(x_concat, f'fake_imgs/fake{i}.png')
#         if i > 30:
#             break

In [81]:
%matplotlib inline
generator.to('cuda')
with torch.no_grad():
    for i, data in enumerate(data_loader):
        (x_real, c_org, path) = data['source']
        c_trg_list = create_labels(c_org.cuda(), c_dim=4)
        x_fake_list = []#[x_real.cuda()]
        target_list = [x_real]
        for j, c_trg in enumerate(c_trg_list):
            # if j != 3:
            #     continue
            # fake = generator(x_real.cuda(), c_trg)
            # fake = (fake + 1) / 2
            # fake = fake.clamp_(0, 1).detach().cpu()
            x_fake_list.append(generator(x_real.cuda(), c_trg))
            if j != 0:
                target = data['target'][j-1][0]
                # target = (target + 1) / 2
                # target = target.clamp_(0, 1)
                target_list.append(target)
            # print(torch.sum(fake - target))
            # save_image(fake, 'fake.png')
            # save_image(target, 'target.png')
            # source = item['source'][0]
            # source = (source + 1) / 2
            # source = source.clamp_(0, 1)
            # save_image(source, 'source.png')
            # target = item['target'][ind][0]
            # target = (target + 1) / 2
            # target = target.clamp_(0, 1)
            # save_image(target, 'target.png')
        #     break
        # break

        x_concat = torch.cat(x_fake_list, dim=3).data.cpu()
        x_concat = torch.cat([x_concat, torch.cat(target_list, dim=3).data.cpu()], dim=2)
        x_concat = (x_concat + 1) / 2
        x_concat = x_concat.clamp_(0, 1)
        save_image(x_concat, f'fake_imgs/fake{i}.png')
        if i > 30:
            break

In [52]:
# %matplotlib inline
# generator.eval()
# generator.cuda()
# item = next(iter(old_trainloader))
# with torch.no_grad():
#     c_trg_list = create_labels(item[1].cuda(), c_dim=4)
#     x_fake_list = [item[0]]
#     for c_trg in c_trg_list:
#         # ind = 1
#         x_fake_list.append(generator(item[0].cuda(), c_trg).detach().cpu())
#         # fake = generator(item['source'][0].cuda(), create_labels(item['source'][1].cuda(), c_dim=4)[ind+1].cuda()).data.cpu()
#         # fake = (fake + 1) / 2
#         # fake = fake.clamp_(0, 1)
#         # save_image(fake, 'fake.png')
#         # source = item['source'][0]
#         # source = (source + 1) / 2
#         # source = source.clamp_(0, 1)
#         # save_image(source, 'source.png')
#         # target = item['target'][ind][0]
#         # target = (target + 1) / 2
#         # target = target.clamp_(0, 1)
#         # save_image(target, 'target.png')

#     x_concat = torch.cat(x_fake_list, dim=3).data.cpu()
#     x_concat = (x_concat + 1) / 2
#     x_concat = x_concat.clamp_(0, 1)
#     save_image(x_concat, 'fake.png', nrow=1, padding=0)

In [47]:
# %matplotlib inline
# generator.eval()
# generator.cuda()
# item = next(iter(data_loader))
# with torch.no_grad():
#     print(item['source'][0].shape)
#     print(item['source'][1])
#     print(torch.sum(item['source'][0]))
#     c_trg_list = create_labels(item['source'][1].cuda(), c_dim=4)
#     x_fake_list = [item['source'][0]]
#     for c_trg in c_trg_list:
#         # ind = 1
#         x_fake_list.append(generator(item['source'][0].cuda(), c_trg).detach().cpu())
#         # fake = generator(item['source'][0].cuda(), create_labels(item['source'][1].cuda(), c_dim=4)[ind+1].cuda()).data.cpu()
#         # fake = (fake + 1) / 2
#         # fake = fake.clamp_(0, 1)
#         # save_image(fake, 'fake.png')
#         # source = item['source'][0]
#         # source = (source + 1) / 2
#         # source = source.clamp_(0, 1)
#         # save_image(source, 'source.png')
#         # target = item['target'][ind][0]
#         # target = (target + 1) / 2
#         # target = target.clamp_(0, 1)
#         # save_image(target, 'target.png')

#     x_concat = torch.cat(x_fake_list, dim=3).data.cpu()
#     x_concat = (x_concat + 1) / 2
#     x_concat = x_concat.clamp_(0, 1)
#     save_image(x_concat, 'fake.png', nrow=1, padding=0)

torch.Size([1, 3, 256, 256])
tensor([0])
tensor(-168382.7500)
