In [None]:
# Mount Google Drive (Colab only) so the tensors under MyDrive/ are readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
#!unzip /content/drive/MyDrive/out_files.zip
#!unzip /content/drive/MyDrive/MindBigData-Imagenet-IN.zip

In [None]:
import os
import librosa
import librosa.display
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import spectrogram
import numpy as np
from PIL import Image
from tqdm import tqdm


import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import cv2

from IPython.display import display

In [None]:
#@title Conv Blocks
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Return a bare ``nn.Conv2d`` with no normalization layer attached.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: padding added to each spatial side.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    return layer


def conv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, inst_norm=False):
    """Return ``Conv2d`` followed by a normalization layer, wrapped in ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        inst_norm: if truthy use ``InstanceNorm2d``, otherwise ``BatchNorm2d``
            (both built with ``momentum=0.1, eps=1e-5``).

    Returns:
        ``nn.Sequential(conv, norm)`` -- index 0 is the conv, index 1 the norm layer.
    """
    # Choose the norm class once instead of duplicating the whole Sequential per branch.
    norm_cls = nn.InstanceNorm2d if inst_norm else nn.BatchNorm2d
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
        norm_cls(out_channels, momentum=0.1, eps=1e-5),
    )

def tconv(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0):
    """Return a bare ``nn.ConvTranspose2d`` (learned upsampling) with no normalization.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        kernel_size: kernel size of the transposed convolution.
        stride: upsampling stride.
        padding: implicit padding removed from each side of the output.
        output_padding: extra size added to one side of the output.
    """
    return nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    
def tconv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, inst_norm=False):
    """Return ``ConvTranspose2d`` followed by a normalization layer, as ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, output_padding:
            forwarded to ``nn.ConvTranspose2d``.
        inst_norm: if truthy use ``InstanceNorm2d``, otherwise ``BatchNorm2d``
            (both built with ``momentum=0.1, eps=1e-5``).

    Returns:
        ``nn.Sequential(tconv, norm)`` -- index 0 is the transposed conv, index 1 the norm.
    """
    # Choose the norm class once instead of duplicating the whole Sequential per branch.
    norm_cls = nn.InstanceNorm2d if inst_norm else nn.BatchNorm2d
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                           padding=padding, output_padding=output_padding),
        norm_cls(out_channels, momentum=0.1, eps=1e-5),
    )

In [None]:
#@title Generator
class Gen(nn.Module):
    """Five-level U-Net-shaped generator producing an image in (-1, 1).

    Encoder: five stride-2 convs, each halving H and W (64 -> 32 -> 16 -> 8 -> 4 -> 2
    for a 64x64 input). Decoder: five stride-2 transposed convs; each of the first
    four outputs is concatenated with the mirrored encoder feature map (skip
    connection) before the next layer. H and W must therefore be divisible by 32
    so the skip concatenations line up.

    Args:
        dim_c: number of input channels.
        dim_g: base channel width; deeper layers use multiples of it.
        inst_norm: use InstanceNorm2d instead of BatchNorm2d in the normalized layers.
        dim_out: number of output channels (3 = RGB, matching ``Disc``'s ``dim_cy`` default).
    """

    def __init__(self, dim_c=5, dim_g=32, inst_norm=False, dim_out=3):
        super(Gen, self).__init__()
        # Encoder (each layer halves H and W). A deeper 8-level variant (n6-n8 / m6-m8)
        # is unnecessary here: five levels already reach 2x2 for 64x64 inputs.
        self.n1 = conv(dim_c, dim_g, 4, 2, 1)
        self.n2 = conv_n(dim_g, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.n3 = conv_n(dim_g*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.n4 = conv_n(dim_g*4, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.n5 = conv_n(dim_g*4, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        # Decoder (each layer doubles H and W). Input widths are doubled wherever the
        # previous decoder output was concatenated with an encoder skip.
        self.m1 = tconv_n(dim_g*4, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.m2 = tconv_n(dim_g*4*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.m3 = tconv_n(dim_g*4*2, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.m4 = tconv_n(dim_g*4, dim_g, 4, 2, 1, inst_norm=inst_norm)
        self.m5 = tconv(dim_g*2, dim_out, 4, 2, 1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``x`` (N, dim_c, H, W) to (N, dim_out, H, W) with values in (-1, 1)."""
        # Encoder: LeakyReLU(0.2) before every conv except the first.
        n1 = self.n1(x)
        n2 = self.n2(F.leaky_relu(n1, 0.2))
        n3 = self.n3(F.leaky_relu(n2, 0.2))
        n4 = self.n4(F.leaky_relu(n3, 0.2))
        n5 = self.n5(F.leaky_relu(n4, 0.2))
        # Decoder: ReLU -> tconv -> (dropout) -> concat with the mirrored encoder map.
        # NOTE: dropout is forced on (training=True), so it stays active even after
        # generator.eval() and inference is stochastic -- presumably intentional
        # (pix2pix uses test-time dropout as its noise source); use
        # training=self.training if deterministic evaluation is wanted.
        m1 = torch.cat([F.dropout(self.m1(F.relu(n5)), 0.5, training=True), n4], 1)
        m2 = torch.cat([F.dropout(self.m2(F.relu(m1)), 0.5, training=True), n3], 1)
        m3 = torch.cat([F.dropout(self.m3(F.relu(m2)), 0.5, training=True), n2], 1)
        m4 = torch.cat([self.m4(F.relu(m3)), n1], 1)
        m5 = self.m5(F.relu(m4))
        return self.tanh(m5)

In [None]:
#@title Discriminator

class Disc(nn.Module):
    """Patch discriminator scoring (condition, image) pairs.

    The condition ``x`` (``dim_cx`` channels) and image ``y`` (``dim_cy`` channels)
    are concatenated on the channel axis. The pair is mapped to a grid of per-patch
    probabilities in (0, 1), which is 6x6 for 64x64 inputs.

    Args:
        dim_cx: channels of the conditioning input.
        dim_cy: channels of the real/generated image.
        dim_d: base channel width.
        inst_norm: use InstanceNorm2d instead of BatchNorm2d in the normalized layers.
    """

    def __init__(self, dim_cx=5, dim_cy=3, dim_d=32, inst_norm=False):
        super().__init__()
        # Three stride-2 downsamplers, then two stride-1 convs that shrink H/W by one each.
        self.c1 = conv(dim_cx + dim_cy, dim_d, 4, 2, 1)
        self.c2 = conv_n(dim_d, 2 * dim_d, 4, 2, 1, inst_norm=inst_norm)
        self.c3 = conv_n(2 * dim_d, 4 * dim_d, 4, 2, 1, inst_norm=inst_norm)
        self.c4 = conv_n(4 * dim_d, 8 * dim_d, 4, 1, 1, inst_norm=inst_norm)
        self.c5 = conv(8 * dim_d, 1, 4, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """Return per-patch real/fake probabilities of shape (N, 1, H', W')."""
        h = torch.cat((x, y), dim=1)
        for block in (self.c1, self.c2, self.c3, self.c4):
            h = F.leaky_relu(block(h), 0.2)
        return self.sigmoid(self.c5(h))

def weights_init(z):
    """Initializer meant for ``module.apply(weights_init)``.

    - ``*Conv*`` / ``*Linear*`` modules: weight ~ N(0, 0.02), bias = 0.
    - ``*BatchNorm*`` modules: weight ~ N(1, 0.02), bias = 0.
    - Everything else (including InstanceNorm and container modules) is left untouched.

    Parameters that a layer does not have (``bias=False`` convs, ``affine=False``
    BatchNorm) are skipped instead of raising on ``None``.
    """
    cls_name = z.__class__.__name__
    if 'Conv' in cls_name or 'Linear' in cls_name:
        weight_mean = 0.0
    elif 'BatchNorm' in cls_name:
        weight_mean = 1.0
    else:
        return
    # nn.init functions already run under no_grad, so the legacy `.data` is unnecessary.
    if getattr(z, 'weight', None) is not None:
        nn.init.normal_(z.weight, weight_mean, 0.02)
    if getattr(z, 'bias', None) is not None:
        nn.init.constant_(z.bias, 0)

In [None]:
def from_spec(Xdb, p, hop_length=6, win_length=32):
    """Rebuild a 1-D signal from a dB-magnitude spectrogram and a phase with ``librosa.istft``.

    Args:
        Xdb: magnitude spectrogram in dB (e.g. the first output of ``to_spec``).
        p: complex unit-magnitude phase (e.g. the second output of ``to_spec``).
        hop_length, win_length: STFT parameters; must match those used to build ``p``.

    Returns:
        The reconstructed time series.

    ``to_spec`` resizes ``Xdb`` but not ``p``. Multiplying them directly would fail
    to broadcast. When the shapes differ, ``Xdb`` is therefore resized back to
    ``p``'s shape first; the reconstruction is then only approximate.
    """
    if Xdb.shape != p.shape:
        # cv2.resize takes dsize as (width, height) == (n_frames, n_freq_bins).
        Xdb = cv2.resize(Xdb, dsize=(p.shape[1], p.shape[0]), interpolation=cv2.INTER_CUBIC)
    Sinv = librosa.db_to_amplitude(Xdb)
    ts = librosa.istft(Sinv * p, hop_length=hop_length, win_length=win_length)
    return ts

def to_spec(ts, n_fft=128, hop_length=6, win_length=32, size=(64, 64)):
    """Compute a (resized) dB-magnitude spectrogram and its phase for a 1-D signal.

    Args:
        ts: 1-D time series.
        n_fft, hop_length, win_length: forwarded to ``librosa.stft``.
        size: (height, width) of the returned dB spectrogram; the default 64x64
            matches the image size used elsewhere in this notebook. ``None`` skips
            the resize.

    Returns:
        ``(Xdb, p)``. ``Xdb`` is the magnitude in dB (no ``top_db`` clipping),
        resized to ``size``. ``p`` is the complex phase from ``librosa.magphase``
        and is NOT resized.
    """
    X = librosa.stft(ts, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    Smag, p = librosa.magphase(X)
    Xdb = librosa.amplitude_to_db(Smag, top_db=None)
    if size is not None:
        height, width = size
        # cv2.resize expects dsize as (width, height).
        Xdb = cv2.resize(Xdb, dsize=(width, height), interpolation=cv2.INTER_CUBIC)
    return Xdb, p

def scale_minmax(X, min=0.0, max=1.0):
    """Linearly rescale ``X`` so its minimum maps to ``min`` and its maximum to ``max``.

    Args:
        X: array of any shape; min/max are taken over all elements.
        min, max: target range (the names shadow builtins but are kept for callers).

    Returns:
        An array shaped like ``X``. A constant ``X`` (max == min) has no range to
        stretch. Instead of dividing 0 by 0 (an all-NaN result), every element
        maps to ``min``.
    """
    lo = X.min()
    span = X.max() - lo
    if span == 0:
        # (X - lo) is all zeros here, so this yields `min` everywhere with the same
        # dtype the regular path would produce.
        return (X - lo) * 0.0 + min
    X_std = (X - lo) / span
    return X_std * (max - min) + min

In [None]:
def convert_to_spec_matrix(brain_ts, n_fft=128, hop_length=6, win_length=64):
    """Turn a 2-D signal array into a stack of per-channel spectrogram images.

    Each column ``brain_ts[:, ch]`` is passed through ``to_spec``. Its dB magnitude
    is kept and the phase is discarded. The magnitude is then min-max scaled to
    [-1, 1] independently of the other channels.

    Args:
        brain_ts: 2-D array whose columns are individual signals (channels).
        n_fft, hop_length, win_length: STFT parameters forwarded to ``to_spec``.

    Returns:
        Array of shape (n_channels, H, W): one scaled spectrogram per column.
    """
    channel_specs = [
        scale_minmax(
            to_spec(brain_ts[:, ch], n_fft=n_fft, hop_length=hop_length, win_length=win_length)[0],
            min=-1.0,
            max=1.0,
        )
        for ch in range(brain_ts.shape[1])
    ]
    # Stacking on a new leading axis == concatenating expand_dims(..., 0) slices.
    return np.stack(channel_specs, axis=0)

def csv_to_spec_matrix(csv, n_samples=360):
    """Load one recording CSV and convert it to a (channels, H, W) spectrogram stack.

    The CSV is read with no header and its first column as the index, then
    transposed so each original row becomes one column (channel). Only the last
    ``n_samples`` rows (time steps) are kept. Each channel is z-scored before
    ``convert_to_spec_matrix`` is applied.

    Args:
        csv: path or buffer accepted by ``pd.read_csv``.
        n_samples: number of trailing time steps to keep.
    """
    # .iloc makes the positional "last n rows" slice explicit.
    brain_ts = pd.read_csv(csv, index_col=0, header=None).T.iloc[-n_samples:]
    # Per-channel z-score. A flat channel (std 0) would otherwise become all-NaN,
    # so its std is replaced by 1 and the channel normalizes to zeros.
    std = brain_ts.std().replace(0, 1)
    brain_ts = (brain_ts - brain_ts.mean()) / std
    brain_ts = brain_ts.reset_index(drop=True).to_numpy()
    return convert_to_spec_matrix(brain_ts)

In [None]:
def to_imgs(xy, size=(256, 256)):
    """Convert a batch of [-1, 1] image tensors into a list of PIL images.

    Args:
        xy: iterable of (C, H, W) tensors in [-1, 1] (generator tanh output, or
            targets normalized with mean 0.5 / std 0.5). Assumes C == 3 (RGB) --
            TODO confirm before using with other channel counts.
        size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        One ``PIL.Image`` per element of ``xy``.
    """
    images = []
    for el in xy:
        # Map [-1, 1] -> [0, 255]. Clamp so out-of-range values cannot wrap around
        # in the uint8 cast, and round instead of truncating.
        pixels = ((el.detach() * 0.5 + 0.5) * 255).clamp(0, 255).round().to(torch.uint8)
        array = pixels.permute(1, 2, 0).cpu().numpy()
        images.append(Image.fromarray(array).resize(size))
    return images

In [None]:
class TSDataset(Dataset):
    """Map-style dataset over paired (input, target) samples stored in one file.

    ``data_split_path`` must contain a 2-tuple ``(inputs, targets)`` written with
    ``torch.save``; item ``i`` is ``(inputs[i], targets[i])``.

    NOTE(review): ``torch.load`` may unpickle the file -- only load trusted splits.
    """

    def __init__(self, data_split_path):
        inputs, targets = torch.load(data_split_path)
        self.input_data = inputs
        self.output_data = targets

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        sample = self.input_data[idx]
        target = self.output_data[idx]
        return sample, target

In [None]:
# All cached tensors live in one Drive folder.
TENSOR_DIR = '/content/drive/MyDrive/TSS Tensors'
new_imgs_csvs_labels_path = os.path.join(TENSOR_DIR, 'new_imgs_csvs_labels.pt')
ts_images_path = os.path.join(TENSOR_DIR, 'ts_images.pt')
img_array_path = os.path.join(TENSOR_DIR, 'images.pt')
train_path = os.path.join(TENSOR_DIR, 'train.pt')
test_path = os.path.join(TENSOR_DIR, 'test.pt')

# Target images: resize to 64x64, then map [0, 1] -> [-1, 1] to match the
# generator's tanh output range.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# (image paths, EEG CSV paths, labels) -- presumably parallel lists; TODO confirm.
new_imgs, csvs, labels = torch.load(new_imgs_csvs_labels_path)

# Rebuilding the spectrogram/image caches and the train/test split is slow.
# Set REBUILD_SPLITS = True to regenerate train.pt / test.pt from the raw files.
REBUILD_SPLITS = False
if REBUILD_SPLITS:
    if os.path.exists(ts_images_path):
        input_data = torch.load(ts_images_path)
    else:
        input_data = [torch.Tensor(csv_to_spec_matrix(csv)) for csv in tqdm(csvs)]
        torch.save(input_data, ts_images_path)

    if os.path.exists(img_array_path):
        output_data = torch.load(img_array_path)
    else:
        output_data = [transform(Image.open(img).convert('RGB')) for img in tqdm(new_imgs)]
        torch.save(output_data, img_array_path)

    input_train, input_test, output_train, output_test = train_test_split(
        input_data, output_data, test_size=0.15, random_state=42)
    torch.save((input_train, output_train), train_path)
    torch.save((input_test, output_test), test_path)

In [None]:
bs = 16
# Use InstanceNorm when training with a batch of one (presumably because BatchNorm
# batch statistics over a single sample are unreliable).
inst_norm = bs == 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pass inst_norm through so bs == 1 actually selects InstanceNorm, and apply the
# N(0, 0.02) initializer that weights_init defines.
generator = Gen(inst_norm=inst_norm).to(device)
discriminator = Disc(inst_norm=inst_norm).to(device)
generator.apply(weights_init)
discriminator.apply(weights_init)

train_ds = TSDataset(train_path)
test_ds = TSDataset(test_path)

train_dataloader = DataLoader(train_ds, batch_size=bs, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=bs, shuffle=False)

In [None]:
# Adversarial loss on the discriminator's sigmoid outputs, plus pixel-wise L1
# reconstruction loss for the generator.
BCE = nn.BCELoss()
L1 = nn.L1Loss()

# Same hyper-parameters for both networks.
# NOTE(review): AdamW also applies its default weight_decay (0.01); switch to
# optim.Adam if the plain pix2pix optimizer was intended.
LR = 2e-4
BETAS = (0.5, 0.999)
Gen_optim = optim.AdamW(generator.parameters(), lr=LR, betas=BETAS)
Disc_optim = optim.AdamW(discriminator.parameters(), lr=LR, betas=BETAS)

In [None]:
img_list = []
# One list per metric. A chained `a = b = c = d = []` would bind all four names
# to the SAME list object.
Disc_losses, Gen_losses, Gen_GAN_losses, Gen_L1_losses = [], [], [], []
iter_per_plot = 100
epochs = 50
L1_lambda = 100.0  # weight of the L1 term relative to the adversarial term

for ep in range(epochs):
    for i, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)

        # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
        discriminator.zero_grad()
        r_patch = discriminator(x, y)
        # Targets shaped like the patch map itself, instead of a hard-coded
        # (N, 1, 6, 6) that only fits 64x64 inputs.
        r_masks = torch.ones_like(r_patch)
        f_masks = torch.zeros_like(r_patch)
        r_gan_loss = BCE(r_patch, r_masks)
        fake = generator(x)
        # detach(): the discriminator update must not backprop into the generator.
        f_patch = discriminator(x, fake.detach())
        f_gan_loss = BCE(f_patch, f_masks)
        Disc_loss = r_gan_loss + f_gan_loss
        Disc_loss.backward()
        Disc_optim.step()

        # ---- Generator step: fool the updated discriminator + stay close to y in L1 ----
        generator.zero_grad()
        f_patch = discriminator(x, fake)
        f_gan_loss = BCE(f_patch, r_masks)
        L1_loss = L1(fake, y)
        Gen_loss = f_gan_loss + L1_lambda * L1_loss
        Gen_loss.backward()
        Gen_optim.step()

        Disc_losses.append(Disc_loss.item())
        Gen_losses.append(Gen_loss.item())
        Gen_GAN_losses.append(f_gan_loss.item())
        Gen_L1_losses.append(L1_loss.item())

        if (i + 1) % iter_per_plot == 0:
            print('Epoch [{}/{}], Step [{}/{}], disc_loss: {:.4f}, gen_loss: {:.4f},Disc(real): {:.2f}, Disc(fake):{:.2f}, gen_loss_gan:{:.4f}, gen_loss_L1:{:.4f}'.format(
                ep, epochs, i + 1, len(train_dataloader), Disc_loss.item(), Gen_loss.item(),
                r_patch.mean().item(), f_patch.mean().item(), f_gan_loss.item(), L1_loss.item()))

    # ---- End-of-epoch preview on the first test batch ----
    x, y = next(iter(test_dataloader))
    x, y = x.to(device), y.to(device)
    generator.eval()
    with torch.no_grad():
        fake = generator(x)
    generator.train()

    faket, yt = to_imgs(fake), to_imgs(y)
    for fi, yi in zip(faket, yt):
        display(fi)
        display(yi)
        print('-' * 20)

In [1]:
# Preview the generator on the test batch at index PREVIEW_BATCH (or the last batch
# if the test set is shorter) -- presumably to show samples other than the first
# batch, which the training loop already displays.
# NOTE: requires the earlier cells (models, dataloaders, to_imgs) to have run in
# this kernel; in a fresh kernel it fails with NameError.
PREVIEW_BATCH = 3
for i, (x, y) in enumerate(test_dataloader):
    if i == PREVIEW_BATCH:
        break

# Only the selected batch is run through the generator.
x, y = x.to(device), y.to(device)
generator.eval()
with torch.no_grad():
    fake = generator(x)
generator.train()

faket, yt = to_imgs(fake), to_imgs(y)
for fi, yi in zip(faket, yt):
    display(fi)
    display(yi)
    print('-' * 20)

NameError: ignored

In [None]:
#ts = brain_ts[:,2]
#Xdb, p = to_spec(ts, n_fft=128, hop_length=6, win_length=64)
#print(Xdb.shape)
#Image.fromarray(scale_minmax(Xdb, min=0.0, max=255.0).astype(np.uint8))
#plt.figure(figsize=(14, 5))
#librosa.display.specshow(Xdb, sr=128, x_axis='time', y_axis='hz')
#plt.colorbar()