# CNN+LSTM-CNN model
This notebook contains the code for the CNN+LSTM generator and the CNN discriminator, and trains them together as a label-conditioned GAN on the EEG training data.

In [None]:

import os
from google.colab import drive
# Mount Google Drive (forcing a fresh mount) and work from the project folder,
# so relative paths such as 'data/...' resolve inside it.
PROJECT_DIR = '/content/drive/My Drive/c147_project/'
drive.mount('/content/drive', force_remount=True)
os.chdir(PROJECT_DIR)


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
import time
from data import eegData
import numpy as np
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2
# get the device type of machine
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_dataset = eegData('data/X_train_valid.npy', 'data/y_train_valid.npy', device, preprocessing_params={'subsample':6, 'trim':400})

cuda


In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
import time
import numpy as np
import matplotlib.pyplot as plt
!pip install torchinfo --quiet
import torchinfo

def count_parameters(model):
  """Return the total number of scalar values in ``model``'s trainable parameters.

  Parameters with ``requires_grad=False`` (frozen) are not counted.
  """
  total = 0
  for param in model.parameters():
    if not param.requires_grad:
      continue
    total += param.numel()
  return total

# Use the GPU when present; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class CNNLSTMGenerator(nn.Module):
  """
  Generator: five ConvTranspose1d upsampling stages followed by a 2-layer LSTM.

  Input:  (batch, input_dim, seq_in).
  Output: (batch, 22, seq_out). Each transposed conv maps length L to
          2 * (L - 1) + kernel, so an input of length 1 grows
          1 -> 4 -> 10 -> 22 -> 49 -> 100 (seq_out = 100).

  hidden_dims: channel widths of the first four upsampling stages
  (hidden_dims[0..3] are read).
  """
  def __init__(self,input_dim,hidden_dims):
    super(CNNLSTMGenerator,self).__init__()
    self.input_dim = input_dim
    # stages 1-4: ConvTranspose1d (stride 2, no padding) -> BatchNorm -> ReLU
    self.l1 = nn.Sequential(
        nn.ConvTranspose1d(self.input_dim,hidden_dims[0],4,2,0, bias = False),
        nn.BatchNorm1d(hidden_dims[0]),
        nn.ReLU(True)
    )
    self.l2 = nn.Sequential(
        nn.ConvTranspose1d(hidden_dims[0],hidden_dims[1],4,2,0, bias = False),
        nn.BatchNorm1d(hidden_dims[1]),
        nn.ReLU(True)
    )
    self.l3 = nn.Sequential(
        nn.ConvTranspose1d(hidden_dims[1],hidden_dims[2],4,2,0, bias = False),
        nn.BatchNorm1d(hidden_dims[2]),
        nn.ReLU(True)
    )
    self.l4 = nn.Sequential(
        nn.ConvTranspose1d(hidden_dims[2],hidden_dims[3],7,2,0, bias = False),
        nn.BatchNorm1d(hidden_dims[3]),
        nn.ReLU(True)
    )
    # stage 5 projects to the 10 features the LSTM consumes (no BatchNorm)
    self.l5 = nn.Sequential(
        nn.ConvTranspose1d(hidden_dims[3],10,4,2,0, bias = False),
        nn.ReLU(True)
    )
    # 2-layer LSTM: 10 input features -> 22 output features per time step
    self.l6 = nn.LSTM(10,22,2, batch_first = True)


  def forward(self,input):
    """Map z of shape (batch, input_dim, seq_in) to (batch, 22, seq_out)."""
    l1 = self.l1(input)
    l2 = self.l2(l1)
    l3 = self.l3(l2)
    l4 = self.l4(l3)
    l5 = self.l5(l4)  # (batch, 10, seq)
    # The batch_first LSTM needs (batch, seq, features): that is a transpose of
    # the channel and time axes. view()/reshape() would keep memory order and
    # interleave channels with time steps instead of swapping the axes.
    l6, _ = self.l6(l5.transpose(1, 2))  # (batch, seq, 22)
    return l6.transpose(1, 2)            # (batch, 22, seq)



In [None]:
class CNNDiscriminator(nn.Module):
  """
  Convolutional discriminator modelled on the DCGAN PyTorch tutorial
  (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html).

  Input:  (batch, input_dim, seq).
  Output: (batch, 1, L) scores in (0, 1) from the final Sigmoid.
          Each of the five stride-2 convolutions maps a length L to
          ceil(L / 2), and the closing kernel-4 convolution removes 3 more, so
          seq = 100 gives 100 -> 50 -> 25 -> 13 -> 7 -> 4 -> 1, i.e. (batch, 1, 1).

  hidden_dims must contain exactly four channel widths.
  """
  def __init__(self, input_dim, hidden_dims):
    super(CNNDiscriminator, self).__init__()
    assert len(hidden_dims) == 4
    width1, width2, width3, width4 = hidden_dims

    # the first stage has no BatchNorm
    self.conv1 = nn.Sequential(
        nn.Conv1d(input_dim, width1, 3, 2, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    )  # (batch, width1, 50) for seq = 100
    self.conv2 = self._down_block(width1, width2)  # (batch, width2, 25)
    self.conv3 = self._down_block(width2, width3)  # (batch, width3, 13)
    self.conv4 = self._down_block(width3, width4)  # (batch, width4, 7)

    self.end = nn.Sequential(
        nn.Conv1d(width4, 1, 3, 2, 1, bias=False),  # (batch, 1, 4)
        nn.BatchNorm1d(1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv1d(1, 1, 4, 1, 0, bias=False),       # (batch, 1, 1)
        nn.Sigmoid(),
    )

  @staticmethod
  def _down_block(in_channels, out_channels):
    """Stride-2 Conv1d -> BatchNorm1d -> LeakyReLU(0.2); length L becomes ceil(L / 2)."""
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, 3, 2, 1, bias=False),
        nn.BatchNorm1d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )

  def forward(self, x):
    """Score a batch of shape (batch, input_dim, seq); returns (batch, 1, L)."""
    for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.end):
      x = stage(x)
    return x

In [None]:
class GAN(object):
  """
  Trains and logs a vanilla (BCE-loss) GAN whose generator is conditioned on the label.

  The generator input ``z`` has shape (batch, G.input_dim, 1): the first
  ``G.input_dim - 1`` channels are uniform noise (``torch.rand``) and the last
  channel holds the sample's label. The generator must expose ``input_dim``.
  The discriminator's output must match the (batch, 1, 1) BCE targets.
  Batches from the DataLoader must be dicts with 'data' and 'label' entries.

  Keyword args:
    cuda (bool): train on a CUDA device (default False).
    device: CUDA device passed to ``.cuda()`` (default 0).
    learn_rate (float): Adam learning rate for G and D (default 0.0002).
    weight_decay (float): Adam weight decay for G and D (default 0.00001).
  """
  def __init__(self,
               generator: nn.Module,
               discriminator: nn.Module,
               train_valid_data: 'eegData',
               **kwargs
               ):
    self.G = generator
    self.D = discriminator
    self.loss = nn.BCELoss()
    self.dataset = train_valid_data

    # Move the networks to CUDA *before* building the optimizers, as the
    # torch.optim documentation recommends.
    self.cuda_bool = kwargs.get('cuda', False)
    self.cuda_device = kwargs.get('device', 0)
    self.device = torch.device('cpu')
    if self.cuda_bool:
      print("Establishing network on CUDA device: ", torch.cuda.get_device_name(self.cuda_device))
      self.G.cuda(self.cuda_device)
      self.D.cuda(self.cuda_device)
      self.loss.cuda(self.cuda_device)
      # Create training tensors where G actually lives. A bare
      # torch.device('cuda') means the *current* device, which differs from
      # G's device whenever a non-default index was requested.
      first_param = next(self.G.parameters(), None)
      self.device = first_param.device if first_param is not None else torch.device('cuda')

    lr = kwargs.get('learn_rate', 0.0002)
    w_decay = kwargs.get('weight_decay', 0.00001)
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), lr=lr, weight_decay=w_decay)
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), lr=lr, weight_decay=w_decay)

  def _make_z(self, labels, n):
    """Build a generator input of shape (n, G.input_dim, 1) on ``self.device``.

    The first G.input_dim - 1 channels are uniform noise. The last channel is
    ``labels`` reshaped to (n, 1, 1) and cast to float; a label count that is
    not n makes torch.cat raise.
    """
    noise = torch.rand((n, self.G.input_dim - 1, 1), device=self.device)
    labels = labels.reshape(-1, 1, 1).to(device=self.device, dtype=torch.float)
    return torch.cat([noise, labels], dim=1)

  def _discriminator_step(self, eeg_data, z, real_target, fake_target):
    """Run one D update on a real batch and a freshly generated fake batch.

    Returns (d_loss, mean real score, mean fake score, fake_data). fake_data
    stays attached to G's graph so the generator step can reuse it.
    """
    self.d_optimizer.zero_grad()  # covers every D parameter

    real_score = self.D(eeg_data)
    d_loss_real = self.loss(real_score, real_target)
    d_loss_real.backward()

    fake_data = self.G(z)
    # detach: the discriminator update must not send gradients into G
    fake_score = self.D(fake_data.detach())
    d_loss_fake = self.loss(fake_score, fake_target)
    d_loss_fake.backward()

    self.d_optimizer.step()
    d_loss = d_loss_real + d_loss_fake
    return d_loss, real_score.mean().item(), fake_score.mean().item(), fake_data

  def _generator_step(self, fake_data, real_target):
    """Run one G update, pushing D's score on ``fake_data`` toward the real label.

    Returns (g_loss, mean score D gives fake_data after D's update). The
    backward pass also fills D's grads; the next D step zeroes them.
    """
    self.g_optimizer.zero_grad()
    fake_score = self.D(fake_data)
    g_loss = self.loss(fake_score, real_target)
    g_loss.backward()
    self.g_optimizer.step()
    return g_loss, fake_score.mean().item()

  def train(self,
            epochs, batch_size, save_path, verbose=True, print_every=100,
            save_every=250):
    """Train G and D for epochs 0..``epochs`` inclusive (``epochs + 1`` passes).

    Args:
      epochs: index of the last epoch.
      batch_size: samples per batch; a trailing partial batch is dropped.
      save_path: directory for generator checkpoints (created if missing).
      verbose: print losses/scores every ``print_every`` iterations.
      print_every: logging period, in iterations.
      save_every: checkpoint period, in epochs (default 250).

    Fills ``self.G_loss``/``self.D_loss`` with one entry per iteration. Fills
    ``self.generated_test`` with one entry per epoch: the batch/channel mean
    of G's output for a fixed noise vector with label 0.
    """
    self.epochs = epochs
    self.batch_size = batch_size

    self.G_loss = list()
    self.D_loss = list()
    self.generated_test = list()

    # Every batch must contain exactly batch_size samples, so drop the
    # trailing partial batch.
    data_loader = torch.utils.data.DataLoader(
        self.dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Create the checkpoint directory up front so torch.save cannot fail on it.
    if save_path:
      os.makedirs(save_path, exist_ok=True)

    # Fixed (noise, label 0) input used to track the generator once per epoch.
    fixed_z = self._make_z(torch.zeros(1), 1)

    real_label = 1
    fake_label = 0
    iteration = 0
    for epoch in range(self.epochs + 1):
      for i, sample in enumerate(data_loader):
        self.G.train()

        eeg_data = sample['data'].to(self.device)
        n = eeg_data.shape[0]
        z = self._make_z(sample['label'], n)

        # Use separate target tensors rather than refilling one buffer in place.
        real_target = torch.full((n, 1, 1), real_label, device=self.device, dtype=torch.float)
        fake_target = torch.full_like(real_target, fake_label)

        d_loss, real_output_score, fake_score_d, fake_data = \
            self._discriminator_step(eeg_data, z, real_target, fake_target)
        g_loss, fake_score_g = self._generator_step(fake_data, real_target)

        iteration += 1
        if verbose and iteration % print_every == 0:
          print("Iteration : ", iteration)
          print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tReal Score: %.4f\t Fake Scores: %.4f / %.4f'
                  % (epoch, self.epochs, i, len(data_loader),
                     d_loss.item(), g_loss.item(), real_output_score, fake_score_d, fake_score_g))

        self.G_loss.append(g_loss.item())
        self.D_loss.append(d_loss.item())

      # Check how the generator is doing on the fixed input.
      with torch.no_grad():
        self.G.eval()
        fixed_out = self.G(fixed_z).cpu()
        self.generated_test.append(np.mean(fixed_out.numpy(), axis=(0, 1)))

      if epoch % save_every == 0:
        self.save_models(os.path.join(save_path, 'Generator_' + str(epoch) + '.pth'))

  def save_models(self, path):
    """Save the generator's state_dict to ``path`` and return True."""
    torch.save(self.G.state_dict(), path)
    print("Saved generator at " + path)
    return True

In [None]:
# Smoke test: push one (noise, label 0) vector through a fresh generator,
# then build the discriminator and the GAN wrapper used for training.
testinp = torch.rand((1, 99, 1))
test_label = torch.zeros((1, 1, 1))
inp = torch.cat([testinp, test_label], dim=1)
Gen = CNNLSTMGenerator(100, [220, 154, 88, 44])
print(inp.shape)
# eval() + no_grad so this dry run does not fold one random sample into the
# BatchNorm running statistics before training starts.
Gen.eval()
with torch.no_grad():
  x = Gen(inp)
Gen.train()
assert x.shape == (1, 22, 100), x.shape
dis = CNNDiscriminator(22, [18, 14, 10, 6])

# Only request CUDA when a GPU is present, so the cell also runs on CPU.
gan = GAN(Gen, dis, train_dataset, cuda=torch.cuda.is_available())

torch.Size([1, 100, 1])


In [None]:
# Train for epochs 0..5000 on batches of 250; generator checkpoints go to save_path.
gan.train(epochs=5000, batch_size=250,
          save_path='/content/drive/My Drive/c147_project/Models/CNN+LSTM-CNN')

Saved generator at /content/drive/My Drive/c147_project/Models/CNN+LSTM-CNN/Generator_0.pth
Iteration :  100
[1/5000][49/51]	Loss_D: 1.1579	Loss_G: 0.9055	Real Score: 0.5756	 Fake Scores: 0.4407 / 0.4045
Iteration :  200
[3/5000][49/51]	Loss_D: 1.7068	Loss_G: 0.3985	Real Score: 0.5895	 Fake Scores: 0.6871 / 0.6715
Iteration :  300
[5/5000][49/51]	Loss_D: 1.6843	Loss_G: 0.3802	Real Score: 0.5971	 Fake Scores: 0.6840 / 0.6838
Iteration :  400
[7/5000][49/51]	Loss_D: 1.6359	Loss_G: 0.4007	Real Score: 0.5980	 Fake Scores: 0.6700 / 0.6699
Iteration :  500
[9/5000][49/51]	Loss_D: 1.5791	Loss_G: 0.4191	Real Score: 0.6080	 Fake Scores: 0.6577 / 0.6576
Iteration :  600
[11/5000][49/51]	Loss_D: 1.5595	Loss_G: 0.4359	Real Score: 0.6004	 Fake Scores: 0.6468 / 0.6467
Iteration :  700
[13/5000][49/51]	Loss_D: 1.5272	Loss_G: 0.4514	Real Score: 0.6021	 Fake Scores: 0.6368 / 0.6367
Iteration :  800
[15/5000][49/51]	Loss_D: 1.5053	Loss_G: 0.4658	Real Score: 0.6001	 Fake Scores: 0.6277 / 0.6276
Iteration

In [None]:
# Mean generator output for the fixed input after the most recent epoch.
latest_trace = gan.generated_test[-1]
latest_trace.shape

In [None]:
# Per-iteration generator and discriminator losses recorded by GAN.train.
fig, ax = plt.subplots()
ax.plot(gan.G_loss, label='Generator Loss')
ax.plot(gan.D_loss, label='Discriminator Loss')
ax.set_xlabel("Iterations")
ax.set_ylabel("Loss")
ax.set_title("CNN+LSTM-CNN GAN loss diagram")
ax.legend()
plt.show()

In [None]:
# Layer-by-layer torchinfo summaries of both networks for a batch of `batch` samples.
batch = 250
Gtest, Dtest = Gen, dis
for net, in_shape in ((Gtest, (batch, 100, 1)), (Dtest, (batch, 22, 100))):
  net.to(device)
  print(torchinfo.summary(net, input_size=in_shape))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [250, 220, 4]             --
|    └─ConvTranspose1d: 2-1              [250, 220, 4]             88,000
|    └─BatchNorm1d: 2-2                  [250, 220, 4]             440
|    └─ReLU: 2-3                         [250, 220, 4]             --
├─Sequential: 1-2                        [250, 154, 10]            --
|    └─ConvTranspose1d: 2-4              [250, 154, 10]            135,520
|    └─BatchNorm1d: 2-5                  [250, 154, 10]            308
|    └─ReLU: 2-6                         [250, 154, 10]            --
├─Sequential: 1-3                        [250, 88, 22]             --
|    └─ConvTranspose1d: 2-7              [250, 88, 22]             54,208
|    └─BatchNorm1d: 2-8                  [250, 88, 22]             176
|    └─ReLU: 2-9                         [250, 88, 22]             --
├─Sequential: 1-4                        [250, 44, 49]             --