In [None]:
import sys
import os
import scipy.ndimage as nd
import scipy.io as io
import numpy as np
import matplotlib.pyplot as plt
import skimage.measure as sk
import utils
plt.style.use('ggplot')

In [None]:
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms as tfs
from torch.utils import data

import torch

#### Load datasets

In [None]:
imagenet_data = datasets.ImageFolder('data/png/',
                                     transform=tfs.Compose([tfs.RandomChoice([tfs.ColorJitter()]),
                                                           tfs.Resize((112, 112)),
                                                           tfs.Grayscale(1)]))
      
data_loader = data.DataLoader(imagenet_data, batch_size=4, shuffle=True)

In [None]:
# Preview a few images: up to N_PREVIEW consecutive samples starting at
# PREVIEW_OFFSET. A dataset too small for the offset falls back to its first
# samples (indexing i + 10 unconditionally raised IndexError there).
N_PREVIEW = 5
PREVIEW_OFFSET = 10
n_items = len(imagenet_data)
start = PREVIEW_OFFSET if n_items > PREVIEW_OFFSET else 0
preview_indices = range(start, min(start + N_PREVIEW, n_items))

fig, axes = plt.subplots(1, N_PREVIEW, squeeze=False)
for ax in axes[0]:
    ax.axis('off')  # also hides unused panels when fewer images are available
for ax, idx in zip(axes[0], preview_indices):
    image, label = imagenet_data[idx]
    # Images were converted with Grayscale(1), so draw them with a gray
    # colormap instead of the default false-colour map.
    ax.imshow(np.asarray(image), cmap='gray')
    # Title with the real dataset index (the old title showed the loop counter).
    ax.set_title('Sample {} (class {})'.format(idx, label))
fig.tight_layout()
plt.show()

#### Import the generator, discriminator & GAN models and train the GAN

In [None]:
from models.generator import _G
from models.discriminator import _D
from models.gan import GAN
import utils

In [None]:
import time
import pickle

# Training-time augmentation + preprocessing, applied to each image in order.
IMG_SIZE = 112
transforms = [
    tfs.RandomAffine(0, scale=(0.3, 1.)),  # no rotation; random rescale in [0.3, 1.0]
    tfs.Resize((IMG_SIZE, IMG_SIZE)),
    tfs.Grayscale(1),                      # single output channel
    tfs.ToTensor(),
]

gan = GAN(
    epochs=100,
    input_h_w=IMG_SIZE,
    data_path='data/png_clasificados/',
    transforms=transforms,
)
gan.train()
print("Training finished!")

# Visualize the learned generator.
# NOTE(review): gan.epoch presumably holds the last trained epoch — confirm in models/gan.py.
gan.visualize_results(gan.epoch)


In [None]:
from IPython.display import Image

# Animated GIF of generator samples across epochs.
# NOTE(review): presumably written by the GAN training cell above — confirm.
GIF_PATH = '/tmp/GAN_epochs.gif'
with open(GIF_PATH, 'rb') as gif_file:
    gif_bytes = gif_file.read()
# NOTE(review): the bytes are a GIF but are tagged format='png'. This looks like
# the usual workaround for IPython versions that refuse to embed 'gif' — confirm
# before switching to format='gif'.
display(Image(data=gif_bytes, format='png'))

#### Sort vessel profile images into per-class folders for classification

In [None]:
import pandas as pd
import shutil

In [None]:
def create_df_from_files(path='data/perfiles_CATA/clases/'):
    """Build an ``id -> class`` table from one text file per class.

    Every regular, non-hidden file in ``path`` describes one class and lists
    one vessel id per line. Files are visited in sorted name order and
    numbered from 1. ``os.listdir`` order is arbitrary, so sorting keeps the
    id-to-class mapping the same on every machine and every run.

    Parameters
    ----------
    path : str
        Directory containing the per-class id files.

    Returns
    -------
    pandas.DataFrame
        Column ``'id'`` holds each id as a str with surrounding whitespace
        removed. Column ``'class'`` holds an int: the 1-based position of
        the id's file in sorted order. Blank lines are skipped.
    """
    # Sub-directories and hidden files (e.g. editor/OS metadata) are not
    # class lists; counting them would also shift every class number.
    filenames = sorted(
        name for name in os.listdir(path)
        if not name.startswith('.') and os.path.isfile(os.path.join(path, name))
    )
    records = []
    for class_, filename in enumerate(filenames, 1):
        with open(os.path.join(path, filename)) as f:
            for line in f:
                id_ = line.strip()
                if id_:  # blank lines would otherwise become an empty id
                    records.append((id_, class_))
    return pd.DataFrame(records, columns=['id', 'class'])

In [None]:
def create_folder_pytorch_format(df, destination, path):
    """Copy ``<id>.png`` images into ``destination/<class>/`` folders.

    The output has one sub-directory per class label, the layout read by
    ``torchvision.datasets.ImageFolder``.

    Parameters
    ----------
    df : pandas.DataFrame
        The first column holds the image id (file name without ``.png``).
        The second column holds its class label. Columns are read by
        position.
    destination : str
        Root of the per-class output tree. It is created if missing, and a
        trailing path separator is optional.
    path : str
        Root directory searched recursively for the source images.

    Every copied source path is printed. If several files under ``path``
    share the same name, all are copied in ``os.walk`` order, so the last one
    wins. Ids with no matching image are skipped and reported in a one-line
    summary.
    """
    # Index every file under ``path`` once, instead of re-walking the whole
    # tree for each row (which was O(rows x files)).
    found = {}
    for root, _dirs, files in os.walk(path):
        for name in files:
            found.setdefault(name, []).append(os.path.join(root, name))

    missing = []
    for id_, class_ in df.iloc[:, :2].itertuples(index=False, name=None):
        # os.path.join (not string concatenation) so ``destination`` works
        # with or without a trailing slash.
        directory = os.path.join(destination, str(class_))
        os.makedirs(directory, exist_ok=True)
        sources = found.get(str(id_) + '.png', [])
        if not sources:
            missing.append(id_)
        for source in sources:
            print(source)
            shutil.copy(source, directory)
    if missing:
        print('{} id(s) had no matching .png under {}'.format(len(missing), path))

In [None]:
# Build the id -> class table and copy the profile images into an
# ImageFolder-style tree (one sub-folder per class).
# Paths are relative to the repository root, the same convention as the default
# path of create_df_from_files, instead of a user-specific absolute home path.
# NOTE(review): the GAN training cell reads 'data/png_clasificados/', not the
# folder written here — confirm which location is intended.
df = create_df_from_files()
destination = "data/perfiles_CATA/png_clasificados/"
path = "data/perfiles_CATA/png"
create_folder_pytorch_format(df, destination, path)

In [None]:
# Preview the id -> class table instead of dumping every row, plus how many ids
# each class received (a quick sanity check on the per-class files).
display(df.head())
df['class'].value_counts().sort_index()