In [None]:
import numpy as np
import pandas as pd
import os

In [None]:
def get_file_list(path_list, tar_dirs):
    """Map annotation-file entries to the per-video angle files they refer to.

    For every non-blank line of each annotation file: drop the first two
    ``/``-separated components, strip the trailing newline plus the last 6
    characters, and re-root the remainder under the matching ``tar_dirs``
    entry with a ``.txt`` suffix.

    Args:
        path_list: paths of annotation files to read.
        tar_dirs: root prefixes paired positionally with ``path_list``. Each
            is concatenated directly, so it should end with ``'/'``.

    Returns:
        list[str]: angle-file paths in annotation order.

    Raises:
        ValueError: if ``path_list`` and ``tar_dirs`` differ in length
            (``zip`` would otherwise silently drop the extra entries).
    """
    path_list = list(path_list)
    tar_dirs = list(tar_dirs)
    if len(path_list) != len(tar_dirs):
        raise ValueError(
            f'path_list has {len(path_list)} entries but tar_dirs has {len(tar_dirs)}')

    list_file = []
    for notation, tar_dir in zip(path_list, tar_dirs):
        with open(notation, 'r') as f:
            for line in f:
                # Strip the newline explicitly so an unterminated last line
                # loses the same number of characters as every other line.
                entry = line.rstrip('\n')
                if not entry.strip():
                    # Blank lines would otherwise yield a bogus "<tar_dir>.txt"
                    # path whose label lookup fails later.
                    continue
                # NOTE(review): the 6 stripped characters are presumably the
                # video extension plus any trailing field on the line; confirm
                # against the annotation format.
                rel_path = '/'.join(entry.split('/')[2:])[:-6]
                list_file.append(tar_dir + rel_path + '.txt')
    return list_file

In [None]:
def load_angles(file_path):
    """Read a comma-separated table of per-frame values.

    Args:
        file_path: path of a text file with one frame per line, values
            separated by commas.

    Returns:
        list[list[float]]: one list of floats per parseable line. Lines that
        cannot be parsed as floats (headers, blank lines, stray text) are
        skipped.
    """
    angles_list = []
    with open(file_path, 'r') as file:
        for line in file:
            try:
                frame_angles = [float(value) for value in line.strip().split(',')]
            except ValueError:
                # Skip only malformed numeric data. A bare ``except`` would also
                # swallow KeyboardInterrupt/MemoryError and hide real bugs.
                continue
            angles_list.append(frame_angles)
    return angles_list

In [None]:
# Each training annotation file, paired with the directory holding its angle files.
train_sources = [
    ('/kaggle/input/rocog-v2/annotations/syn_ground_train.txt', '/kaggle/input/dsp-rocog/syn_ground/syn_ground/'),
    ('/kaggle/input/rocog-v2/annotations/syn_air_train.txt', '/kaggle/input/dsp-rocog/syn_air/syn_air/'),
    ('/kaggle/input/rocog-v2/annotations/real_air_train.txt', '/kaggle/input/rocog-v2/real/air/'),
    ('/kaggle/input/rocog-v2/annotations/real_ground_train.txt', '/kaggle/input/rocog-v2/real/ground/'),
]
train_annotations, train_roots = map(list, zip(*train_sources))
train_list = get_file_list(train_annotations, train_roots)

In [None]:
# Each real-world test annotation file, paired with the directory holding its angle files.
test_sources = [
    ('/kaggle/input/rocog-v2/annotations/real_ground_test.txt', '/kaggle/input/rocog-v2/real/ground/'),
    ('/kaggle/input/rocog-v2/annotations/real_air_test.txt', '/kaggle/input/rocog-v2/real/air/'),
]
test_annotations, test_roots = map(list, zip(*test_sources))
test_list = get_file_list(test_annotations, test_roots)

In [None]:
import sys
# Raise the interpreter's recursion cap to 100000. The reason is not visible
# here. Presumably it is meant to tolerate deep recursion later, e.g.
# RoCogDataset.__getitem__ calling itself as a fallback. TODO confirm.
# NOTE(review): a limit this high can crash the process with a C stack
# overflow instead of raising RecursionError.
sys.setrecursionlimit(100000)
# Echo the active limit as the cell's output.
sys.getrecursionlimit()

In [None]:
!pip install -qq mamba-ssm 
!pip install -qq causal-conv1d>=1.2.0

In [None]:
!git clone https://github.com/alxndrTL/mamba.py.git
%cd /kaggle/working/mamba.py

In [None]:
# Gesture class name -> integer class index (its position in this list).
label_map = {
    gesture: index
    for index, gesture in enumerate([
        'Advance', 'Attention', 'Rally', 'MoveForward',
        'Halt', 'FollowMe', 'MoveInReverse',
    ])
}

In [None]:
import torch 
import torch.nn as nn 
import torch.optim as optim 
from torch.utils.data import Dataset, DataLoader

In [None]:
# Sanity check: the one-hot target row built with torch.eye for the 'Halt' class.
torch.eye(len(label_map))[label_map['Halt']]

In [None]:
def pad_and_resize(sequence, max_length=256):
    """Return a copy of the 2-D ``sequence`` with exactly ``max_length`` rows.

    Dim 0 (presumably time) is the axis that is adjusted; dim 1 is kept.
    Inputs with at most ``max_length`` rows are zero-padded at the end.
    Longer inputs are linearly resampled down to ``max_length`` rows
    (``align_corners=False``).
    """
    functional = torch.nn.functional
    n_steps = len(sequence)
    if n_steps > max_length:
        # interpolate expects (batch, channels, length): put dim 0 last.
        channels_first = sequence.unsqueeze(0).permute(0, 2, 1)
        resampled = functional.interpolate(channels_first, size=max_length,
                                           mode='linear', align_corners=False)
        return resampled.permute(0, 2, 1).squeeze(0)
    # Pad spec reads (dim -1 before, dim -1 after, dim -2 before, dim -2 after).
    return functional.pad(sequence, (0, 0, 0, max_length - n_steps),
                          mode='constant', value=0)

In [None]:
class RoCogDataset(Dataset):
    """Dataset of angle sequences read from a list of per-video ``.txt`` files.

    The class label is the name of each file's parent directory, looked up
    in ``label_map`` and returned as a one-hot float row. Inputs are read
    with ``load_angles`` and brought to ``max_len`` rows with
    ``pad_and_resize``.
    """

    def __init__(self, list_path, max_len = 256, label_map = label_map):
        # The label_map default is bound once, when the class is defined, to
        # the notebook's global mapping. It is only read, never mutated.
        self.data_dir_list = list_path
        self.max_len = max_len
        self.label_map = label_map
        self.num_label = len(label_map)

    def __len__(self):
        return len(self.data_dir_list)

    def _load_inputs(self, path):
        """Return the angle tensor for ``path``, or None if it is unreadable or empty."""
        try:
            inputs = torch.Tensor(load_angles(path))
        except Exception:
            # Best-effort: missing files and ragged rows are replaced by a
            # fallback sample rather than aborting the epoch.
            return None
        return inputs if len(inputs) > 0 else None

    def __getitem__(self, idx):
        path = self.data_dir_list[idx]
        label = torch.eye(self.num_label)[self.label_map[path.split('/')[-2]]]

        inputs = self._load_inputs(path)
        if inputs is None:
            # Substitute sample 0 (together with sample 0's own label). If
            # sample 0 is itself bad, fail loudly instead of recursing forever.
            if idx == 0:
                raise RuntimeError(
                    f'sample 0 ({path}) is unreadable or empty and cannot serve as the fallback sample')
            return self.__getitem__(0)
        # Pass the configured length explicitly so max_len is actually honoured.
        return pad_and_resize(inputs, self.max_len), label

In [None]:
# One dataset per split, both with the default max_len and label_map.
train_dataset, test_dataset = map(RoCogDataset, (train_list, test_list))

In [None]:
from mambapy.mamba import Mamba, MambaConfig

In [None]:
class CosineSimilarityClassifier(nn.Module):
    """Bias-free linear head whose scores are cosine similarities.

    Each output entry is the cosine between one input row and one class's
    weight vector, so scores lie in [-1, 1] (up to rounding).

    Args:
        input_size: feature dimension of the 2-D ``(batch, input_size)`` input.
        num_classes: number of class weight vectors, i.e. output columns.
    """

    def __init__(self, input_size, num_classes):
        super(CosineSimilarityClassifier, self).__init__()
        self.num_classes = num_classes
        # weight has shape (num_classes, input_size): one row per class.
        self.Wstar_layer = nn.Linear(input_size, num_classes, bias=False)

    def forward(self, x):
        """Return ``(batch, num_classes)`` cosine similarities for 2-D ``x``.

        ``normalize`` clamps each norm at a small epsilon. An all-zero row
        (input or class weight) therefore yields similarity 0 instead of the
        NaN produced by a plain ``x / ||x||`` division.
        """
        norm_z = torch.nn.functional.normalize(x, p=2, dim=1)
        # Unit-normalize each class's weight row, then lay classes out as columns.
        norm_Wstar = torch.nn.functional.normalize(self.Wstar_layer.weight, p=2, dim=1).T
        return torch.mm(norm_z, norm_Wstar)

In [None]:
class MambaClassification(nn.Module):
    """Four parallel Mamba stacks over the same input, plus a cosine-similarity head.

    Each branch output is flattened, and the four are concatenated into
    ``4 * embed_dim * seq_len`` features per sample. This assumes each
    branch returns ``(batch, seq_len, embed_dim)``; TODO confirm against
    mambapy. The concatenated features go to ``CosineSimilarityClassifier``.

    Args:
        input_dim: per-step feature size of the raw input. If it differs from
            ``embed_dim``, a linear projection is applied first. When they are
            equal, the input is used as-is and no parameters are added.
        embed_dim: Mamba model width (``d_model``).
        d_state, d_conv, expand: forwarded to ``MambaConfig`` as ``d_state``,
            ``d_conv`` and ``expand_factor``. These defaults are assumed to match
            MambaConfig's own (16/4/2); TODO confirm.
        seq_len: sequence length the inputs were padded/resampled to. It fixes
            the classifier's input size.
        n_class: number of output classes.
        n_layers: layers per Mamba branch.
        n_heads: stored for backward compatibility; not used by the model.
    """

    def __init__(self, input_dim=14, embed_dim=64, d_state=16, seq_len = 256,
                 d_conv=4, expand=2, n_class=7, n_layers=2, n_heads=2):
        super(MambaClassification, self).__init__()
        self.n_heads = n_heads
        self.relu = nn.ReLU(inplace = True)
        self.config = MambaConfig(d_model=embed_dim, n_layers=n_layers,
                                  d_state=d_state, d_conv=d_conv,
                                  expand_factor=expand)
        self.mamba1 = Mamba(self.config)
        self.mamba2 = Mamba(self.config)
        self.mamba3 = Mamba(self.config)
        self.mamba4 = Mamba(self.config)
        self.flatten = nn.Flatten()
        self.cosine = CosineSimilarityClassifier(4*embed_dim*seq_len, n_class)
        # Registered last so the other modules' random initialisation is
        # unaffected. nn.Identity has no parameters, so state_dicts of models
        # with input_dim == embed_dim keep the same keys.
        if input_dim == embed_dim:
            self.input_proj = nn.Identity()
        else:
            self.input_proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = self.input_proj(x)
        branches = (self.mamba1, self.mamba2, self.mamba3, self.mamba4)
        features = torch.cat([self.flatten(branch(x)) for branch in branches], dim=1)
        scores = self.cosine(features)
        # NOTE(review): ReLU zeroes every negative cosine, so the logits fed to
        # CrossEntropyLoss are confined to [0, 1]. With 7 classes the best
        # achievable softmax probability is e / (e + 6), about 0.31. Consider
        # dropping the ReLU and/or scaling the cosines by a temperature.
        return self.relu(scores)

In [None]:
# Width 14 matches input_dim's default of 14. n_heads is only stored, never used.
model = MambaClassification(n_layers=1, embed_dim=14, n_heads=4)

In [None]:
# Total number of parameter elements in the model.
print(sum(map(torch.numel, model.parameters())))

In [None]:
# Wrap for multi-GPU execution with DataParallel's default device list, then move to CUDA.
# The wrapped network lives at `model.module`, so saved state_dict keys carry a "module." prefix.
model = nn.DataParallel(model).to('cuda')

In [None]:
# Targets are one-hot float rows (torch.eye in RoCogDataset), which
# CrossEntropyLoss treats as class probabilities.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Same batch size for both splits; only the training split is reshuffled each epoch.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=1024, shuffle=reshuffle)
    for dataset, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Authenticate with W&B without embedding a secret in the notebook.
# wandb.login() reads the WANDB_API_KEY environment variable (e.g. populated
# from Kaggle Secrets) or prompts for the key.
# NOTE: an API key was committed in an earlier version of this cell. It is
# exposed in the notebook's history and should be revoked/rotated.
import wandb
wandb.login()

In [None]:
import wandb

# Track the run's real hyperparameters. They are read from the live objects
# so they cannot drift from what is actually trained. Metrics (train/val loss
# and accuracy) are recorded separately via wandb.log in the training loop.
wandb.init(
    project="dsp-rocog",
    config={
        "learning_rate": optimizer.param_groups[0]["lr"],
        "batch_size": train_loader.batch_size,
        "max_len": train_dataset.max_len,
        "num_train_samples": len(train_dataset),
        "num_test_samples": len(test_dataset),
    },
)

In [None]:
import gc
from tqdm import tqdm

NUM_EPOCHS = 100
DEVICE = 'cuda'

# Best test accuracy so far; a checkpoint is written whenever it improves.
# NOTE(review): checkpoints are selected on test_loader (the *_test annotation
# splits), so the best reported test accuracy is optimistically biased. Hold
# out a separate validation split for model selection.
last_acc = 0
for epoch in range(NUM_EPOCHS):
    # ---- training ----
    model.train()
    correct = 0
    total = 0
    loss_sum = 0.0
    pbar = tqdm(total=len(train_loader))
    for inputs, labels in train_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        output = model(inputs)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Targets are one-hot rows; argmax turns both sides into class indices.
        predicted = output.detach().argmax(dim=1)
        target = labels.argmax(dim=1)
        batch_size = target.size(0)
        total += batch_size
        correct += (predicted == target).sum().item()
        # criterion averages over the batch; re-weight to get an exact epoch mean.
        loss_sum += loss.item() * batch_size
        wandb.log({"train_loss": loss.item(), 'train_acc': correct / total})
        pbar.update(1)
    pbar.close()
    train_loss = loss_sum / max(total, 1)
    print('Epoch', epoch, ' Loss -->', train_loss, 'Train acc -->', correct / max(total, 1))

    # ---- evaluation ----
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            output = model(inputs)
            target = labels.argmax(dim=1)
            batch_size = target.size(0)
            # Sample-weighted so a smaller final batch does not skew the mean.
            loss_sum += criterion(output, labels).item() * batch_size
            total += batch_size
            correct += (output.argmax(dim=1) == target).sum().item()

    test_acc = correct / max(total, 1)
    val_loss = loss_sum / max(total, 1)
    wandb.log({"val_loss": val_loss, 'val_acc': test_acc})
    print('Epoch', epoch, ' Loss -->', val_loss, 'Test acc -->', test_acc)
    if test_acc > last_acc:
        # model is DataParallel-wrapped, so keys carry a "module." prefix.
        torch.save(model.state_dict(), f'/kaggle/working/model_epoch_{epoch}.pth')
        print('Weight saved')
        last_acc = test_acc

    # Release Python garbage and cached GPU blocks once per epoch. Doing this
    # after every batch only adds a collection/sync on each training step.
    gc.collect()
    torch.cuda.empty_cache()

wandb.finish()

In [None]:
# Presumably a cleanup cell: closes the training loop's tqdm bar if that cell
# was interrupted mid-epoch and left the bar open. tqdm's close() returns early
# on an already-closed bar, so running this after a normal finish is harmless.
pbar.close()