In [40]:
import torch
import numpy as np
from pathlib import Path
import logging
from torch.utils.data import DataLoader
from dataset import ViTacVisDataset
from torch.utils.tensorboard import SummaryWriter
import argparse
from torch import nn
import torch.nn.functional as F

In [41]:
# parser = argparse.ArgumentParser("Train model.")
# parser.add_argument("--epochs", type=int, help="Number of epochs.", required=True)
# parser.add_argument("--data_dir", type=str, help="Path to data.", required=True)
# parser.add_argument(
#     "--checkpoint_dir", type=str, help="Path for saving checkpoints.", required=True
# )

# parser.add_argument("--lr", type=float, help="Learning rate.", required=True)
# parser.add_argument(
#     "--sample_file", type=int, help="Sample number to train from.", required=True
# )
# parser.add_argument(
#     "--batch_size", type=int, help="Batch Size.", required=True
# )

In [42]:
class FLAGS():
    def __init__(self):
        self.data_dir = '/home/tasbolat/some_python_examples/data_VT_SNN/'
        self.batch_size = 8
        self.sample_file = 1
        self.lr = 0.01
        self.epochs = 100
        self.output_size = 20
args = FLAGS()

In [43]:
device = torch.device("cuda:1")

In [44]:
train_dataset = ViTacVisDataset(
    path=args.data_dir, sample_file=f"train_80_20_{args.sample_file}.txt", output_size=args.output_size
)
train_loader = DataLoader(
    dataset=train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4
)
test_dataset = ViTacVisDataset(
    path=args.data_dir, sample_file=f"test_80_20_{args.sample_file}.txt", output_size=args.output_size
)
test_loader = DataLoader(
    dataset=test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4
)

In [80]:
# define NN models
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=(5,5), stride=(2,2))
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(5,5), stride=(1,1))

    def forward(self, x):
#         print('CNN input:', x.shape)
        x = self.conv1(x)
        # print('Conv', x.size()) 
        x = F.max_pool2d(x, 2)
        #print('Pool 1', x.size())
        x = F.relu(x)
        
        x = self.conv2(x)
        # print('Conv', x.size()) 
        x = F.max_pool2d(x, 2)
        #print('Pool 2', x.size())
        x = F.relu(x)
        x = x.view(-1, 8*5*3)
        return x

In [87]:
class CNN_LSTM(nn.Module):

    def __init__(self):
        
        super(CNN_LSTM, self).__init__()
        self.input_size = 8*5*3
        self.hidden_dim = 32
        self.num_layers = 1
        
        self.cnn = CNN()

        # Define the LSTM layer
        self.gru = nn.GRU(self.input_size, self.hidden_dim, self.num_layers)

        # Define the output layer
        self.fc = nn.Linear(self.hidden_dim, 20)
        
        #self.fc_mlp = nn.Linear(6300, self.input_size)

    def forward(self, x):
        
        #print('Model input ', x.size())
        batch_size, C, H, W, sequence_size = x.size()
        
        # create CNN embedding
        cnn_embed_seq = []
        for t in range(sequence_size):
            cnn_out = self.cnn(x[...,t])
            cnn_embed_seq.append(cnn_out)
        
        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0)
        #print('cnn_embed_seq: ', cnn_embed_seq.shape)
        
        # forward on GRU
        # (seq_len, batch, input_size)
        gru_out, hidden = self.gru(cnn_embed_seq)
        #print('gru out: ', gru_out.shape)
        
        
        # Only take the output from the final timetep
        # Can pass on the entirety of lstm_out to the next layer if it is a seq2seq prediction
        y_pred = self.fc(gru_out[-1, :, :])
        
        #print(y_pred.shape)
        return y_pred

In [88]:
net = CNN_LSTM().to(device)
# Create snn loss instance.
criterion = nn.CrossEntropyLoss()
# Define optimizer module.
optimizer = torch.optim.RMSprop(
    net.parameters(), lr=0.0001, weight_decay=0.5)

In [89]:
train_accs = []
test_accs = []
train_loss = []
test_loss = []
for epoch in range(args.epochs):
    # Training loop.
    net.train()
    correct = 0
    batch_loss = 0
    train_acc = 0
    for i, (in_viz, _, label) in enumerate(train_loader, 0):

        in_viz = in_viz.to(device)
        label = label.to(device)
        # Forward pass of the network.
        #print(in_viz.shape)
        out = net.forward(in_viz)
        #print(out_tact.shape)
        # Calculate loss.
        #print(label.shape)
        loss = criterion(out, label)
        #print(loss)

        batch_loss += loss.cpu().data.item()
        # Reset gradients to zero.
        optimizer.zero_grad()
        # Backward pass of the network.
        loss.backward()
        # Update weights.
        optimizer.step()

        _, predicted = torch.max(out.data, 1)
        correct += (predicted == label).sum().item()

    # Reset training stats.
    train_acc = correct/len(train_loader.dataset)
    train_loss.append(batch_loss)
    train_accs.append(train_acc)
    #print(train_acc, batch_loss)

    # testing
    net.eval()
    correct = 0
    batch_loss = 0
    test_acc = 0
    with torch.no_grad():
        for i, (in_viz, _, label) in enumerate(test_loader, 0):
            in_viz = in_viz.to(device)
            # Forward pass of the network.
            out = net.forward(in_viz)
            label = label.to(device)
            _, predicted = torch.max(out.data, 1)
            correct += (predicted == label).sum().item()
            # Calculate loss.
            loss = criterion(out, label)
            batch_loss += loss.cpu().data.item()

    test_loss.append(batch_loss)
    test_acc = correct/len(test_loader.dataset)
    test_accs.append(test_acc)
    if epoch%1 == 0:
        print(epoch, 'Train:', train_acc, 'Test:', test_acc)

0 Train: 0.05 Test: 0.05
1 Train: 0.05 Test: 0.05


KeyboardInterrupt: 

In [None]:
model_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])

In [36]:
params

608724