## Imports

In [1]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm

## Configuration

In [None]:
# Apply seaborn's default theme to every matplotlib figure drawn below.
sns.set_theme()

# Fix the PyTorch and NumPy RNG seeds so repeated runs are reproducible.
# (torch.manual_seed seeds the generators of all devices, CUDA included.)
torch.manual_seed(42)
np.random.seed(42)

# Use the default CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# Release cached CUDA allocations left over from earlier runs of the notebook.
# Check the device *type*: a torch.device does not compare equal to the plain
# string "cuda", so the previous `device == 'cuda'` test never fired.
if device.type == "cuda":
    torch.cuda.empty_cache()

Using device: cuda


## Data Loading

## Model Design

In [None]:

class ConvBlock(nn.Module):
    """One feature-extraction stage: 3x3 conv -> ReLU -> 2x2 max-pool.

    The convolution has no padding, so it trims 2 pixels from each spatial
    dimension; the pool then halves (rounding down) what is left. An input of
    spatial size ``H x W`` therefore comes out as
    ``(H - 2) // 2 x (W - 2) // 2``.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        # nn.Sequential takes its modules as positional arguments; passing a
        # list raises TypeError ("list is not a Module subclass").
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Apply the conv/ReLU/pool stack to an (N, C, H, W) or (C, H, W) tensor."""
        # Without an explicit forward, nn.Module raises NotImplementedError
        # when the block is called.
        return self.block(x)

class ConvNet(nn.Module):
    """Four ConvBlock stages (32 -> 64 -> 128 -> 256 channels), then Flatten.

    ``nn.Flatten()`` flattens from dim 1, so a batched (N, C, H, W) input
    yields (N, 256 * H' * W'), where H', W' are the spatial sizes after the
    four stages.

    NOTE(review): the downstream LSTM expects 256 input features, which only
    holds when H' = W' = 1. With unpadded 3x3 convs plus 2x2 pooling, that
    works out to input sides of 46..61 px. Confirm against the data pipeline.

    Args:
        in_channels: Channels of the input image (3 for RGB).
    """

    def __init__(self, in_channels=3):
        super(ConvNet, self).__init__()
        # Modules are passed positionally; nn.Sequential rejects a list.
        self.block = nn.Sequential(
            ConvBlock(in_channels, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return the flattened feature map for image tensor ``x``."""
        return self.block(x)
    
class RecurrentBlock(nn.Module):
    """Bidirectional LSTM (256 inputs, 128 hidden units per direction).

    State carries over between calls. The final ``(h_n, c_n)`` of each call
    is stored on ``self.hidden`` and used as the initial state of the next
    call. Call ``reset_hidden()`` to start again from zeros.

    Args:
        sequence_length: Passed to ``nn.LSTM`` as ``num_layers``.
            NOTE(review): this sets the LSTM's *depth* (stacked layers), not
            a sequence length; nn.LSTM accepts sequences of any length.
            Confirm that a 10-layer stack is intended.
    """

    def __init__(self, sequence_length=10):
        super(RecurrentBlock, self).__init__()
        # (h_n, c_n) produced by the most recent forward call, or None.
        self.hidden = None
        self.lstm = nn.LSTM(256, 128, bidirectional=True, num_layers=sequence_length)

    def reset_hidden(self):
        """Drop the carried-over state so the next call starts from zeros."""
        self.hidden = None

    def forward(self, x):
        """Run the LSTM on ``x`` and return its per-step output.

        The state kept from the previous call is detached before it is fed
        back in. That state belongs to the previous call's autograd graph,
        which ``backward()`` frees. Reusing it as-is makes the next
        ``backward()`` fail with "Trying to backward through the graph a
        second time". The freshly produced state stored on ``self.hidden``
        stays attached, so callers can still backpropagate through it.
        """
        previous = self.hidden
        if previous is not None:
            previous = tuple(state.detach() for state in previous)
        output, self.hidden = self.lstm(x, previous)
        return output
    

class FearBlock(nn.Module):
    """CNN feature extractor -> LSTM -> MLP head producing 5 sigmoid scores.

    The head ends in ``nn.Sigmoid``, so outputs already lie in (0, 1).
    NOTE(review): pair them with a probability loss such as ``nn.BCELoss``,
    not a logits loss such as ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self):
        super(FearBlock, self).__init__()
        self.cnn = ConvNet()
        self.rnn = RecurrentBlock()
        # The head's 128 inputs match one direction's LSTM hidden size.
        # Modules are passed positionally; nn.Sequential rejects a list.
        self.tail = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 5),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the 5 sigmoid scores computed from the LSTM's final hidden state."""
        x = self.cnn(x)
        # Only the final hidden state stored on self.rnn.hidden is used; the
        # per-step LSTM output is discarded.
        self.rnn(x)
        # h_n is laid out layer-major, direction-minor, so [-1] is the top
        # layer's reverse-direction state. NOTE(review): the forward
        # direction's final state (index -2) is ignored; confirm intended.
        x = self.rnn.hidden[0][-1]
        return self.tail(x)
        
