In [25]:
from tqdm.auto import tqdm
from tqdm.utils import _term_move_up

prefix = _term_move_up() + '\r'

import torch
from torch import nn
from torch.utils.data import DataLoader

from data import DrawingDataset

# from mamba import Mamba, MambaConfig
from mamba_ssm import Mamba
# from mamba_ssm.modules.mamba_simple import Block
from mamba_ssm.models.mixer_seq_simple import create_block

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
print(torch.cuda.get_device_properties(0))

True
NVIDIA GeForce GTX 1080
_CudaDeviceProperties(name='NVIDIA GeForce GTX 1080', major=6, minor=1, total_memory=8112MB, multi_processor_count=20)


In [27]:
train_dataset = DrawingDataset(data_path="./data", split="train", max_length=50)
val_dataset = DrawingDataset(data_path="./data", split="valid", max_length=50)
test_dataset = DrawingDataset(data_path="./data", split="test", max_length=50)

train = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
val = DataLoader(dataset=val_dataset, batch_size=256, shuffle=True)
test = DataLoader(dataset=test_dataset, batch_size=256, shuffle=True)

  0%|          | 0/345 [00:00<?, ?it/s]

  7%|▋         | 25/345 [00:10<02:17,  2.33it/s]
  7%|▋         | 25/345 [00:00<00:05, 57.37it/s]
  7%|▋         | 25/345 [00:00<00:05, 59.03it/s]


In [28]:
# class mambaBlock(nn.Module):
#     def __init__(self, d_model, d_state, d_conv, expand, n_layers):
#         super(mambaBlock, self).__init__()
#         self.d_model = d_model
#         self.d_state = d_state
#         self.d_conv = d_conv
#         self.expand = expand
#         self.n_layers = n_layers
#         self.layers = []

#         self.layers = nn.ModuleList([Mamba(d_model, d_state, d_conv, expand) for _ in range(n_layers)])

#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return x

class customModel(nn.Module):
    def __init__(self, nb, no, ns, embed_dim, state_hidden):
        super(customModel, self).__init__()
        
        self.embed_dim = embed_dim
        self.state_hidden = state_hidden
        self.proj = nn.Linear(in_features=5, out_features=self.embed_dim)
        
        self.m1 = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'm{i}') for i in range(nb)])
        self.leftm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'l{i}') for i in range(no)])
        self.rightm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'r{i}') for i in range(ns)])
        
        self.offset_out = nn.Linear(in_features=self.embed_dim, out_features=2)
        self.state_out = nn.Linear(in_features=self.embed_dim, out_features=3)

    def forward(self, x): # x is of shape (B, L, 5) (Batchsize, sequence length, dimension)
        # x = self.m1(x)
        #x = self.proj(x)
        hidden_states, residuals = x, None
        for layer in self.m1:
            hidden_states, residuals = layer(hidden_states, residuals)
        
        left_hidden_states, left_residuals = hidden_states, residuals
        right_hidden_states, right_residuals = hidden_states, residuals

        for layer in self.leftm:
            left_hidden_states, left_residuals = layer(left_hidden_states, left_residuals)
        for layer in self.rightm:
            right_hidden_states, right_residuals = layer(right_hidden_states, right_residuals)
        
        offset_out = self.offset_out(left_hidden_states)
        state_out = self.state_out(right_hidden_states)
        return offset_out, state_out

In [29]:
from torch.utils.tensorboard import SummaryWriter
import time

model = customModel(nb=4, no=2, ns=2, embed_dim=5, state_hidden=128).to("cuda")

offset_crit = nn.MSELoss()
state_crit = nn.CrossEntropyLoss()

optimizer = torch.optim.RAdam(model.parameters(), lr=5e-4)

writer = SummaryWriter('./logs')

log_interval = 25
epochs = 25

def train_model(model, data_loader, optimizer, epoch):
    model.train()
    
    size = len(data_loader)
    
    # Total Losses
    total_loss = 0
    total_offset_loss = 0
    total_state_loss = 0
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    start_time = time.time()
    
    for i, data in enumerate(tqdm(data_loader)):
        optimizer.zero_grad()
        
        inputs, targets = data
        inputs = inputs.to("cuda")
        targets = targets.to("cuda")
        
        offsets, states = model(inputs)
        offset_loss = offset_crit(offsets, targets[:, :, :2])
        
        state_target = targets[:, :, 2:]
        #print(states.transpose(1, 2).shape, state_target.argmax(dim=-1).shape)
        
        
        state_loss = state_crit(states.transpose(1, 2), state_target.argmax(dim=-1))
        loss = offset_loss + state_loss
        
        loss.backward(retain_graph=True)
        
        for name, param in model.named_parameters():
            torch.nn.utils.clip_grad_norm_(param, max_norm=1.0)
        
        optimizer.step()
        
        running_loss += loss.item()
        running_offset_loss += offset_loss.item()
        running_state_loss += state_loss.item()
        
        total_loss += loss.item()
        total_offset_loss += offset_loss.item()
        total_state_loss += state_loss.item()
        
        if i % log_interval == 0 and i > 0:
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = running_loss / log_interval
            cur_offset_loss = running_offset_loss / log_interval
            cur_state_loss = running_state_loss / log_interval
            tqdm.write(f'{prefix}| epoch {(epoch+1):3d} | {i:5d}/{size:5d} batches | '
                  f'| ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | offset {cur_offset_loss:5.2f} | state {cur_state_loss:5.2f}')
            time.sleep(0)
            running_loss = 0
            running_offset_loss = 0
            running_state_loss = 0
            start_time = time.time()
    
    return total_loss / size, total_offset_loss / size, total_state_loss / size
        
def evaluate_model(model, data_loader):
    model.eval()
    size = len(data_loader)
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    with torch.no_grad():
        for data in tqdm(data_loader):
            
            inputs, targets = data
            inputs = inputs.to("cuda")
            targets = targets.to("cuda")
            
            offsets, states = model(inputs)
            offset_loss = offset_crit(offsets, targets[:, :, :2])
            state_loss = state_crit(states.transpose(1, 2), targets[:, :, 2:].argmax(dim=-1))
            loss = offset_loss + state_loss
            
            running_loss += loss.item()
            running_offset_loss += offset_loss.item()
            running_state_loss += state_loss.item()
    
    return running_loss / size, running_offset_loss / size, running_state_loss / size
            


In [30]:
torch.autograd.set_detect_anomaly(True)
for epoch in range(epochs):
    train_loss, train_offset_loss, train_state_loss = train_model(model, train, optimizer, epoch)
    print(f"Training: Epoch: {epoch+1}, Loss: {train_loss}, offset_loss: {train_offset_loss}, state_loss: {train_state_loss}")
    writer.add_scalar("Train/Loss/Epoch", train_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", train_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", train_state_loss, epoch)
    
    val_loss, val_offset_loss, val_state_loss = evaluate_model(model, val)
    print(f"Validation: Epoch: {epoch+1}, Loss: {val_loss}, offset_loss: {val_offset_loss}, state_loss: {val_state_loss}")
    writer.add_scalar("Train/Loss/Epoch", val_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", val_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", val_state_loss, epoch)

  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   0 |  3700/ 3708 batches | | ms/batch 58.82 | loss 1457.67 | offset 1457.38 | state  0.29
Training: Epoch: 1, Loss: 1768.4182645433543, offset_loss: 1767.9618721214056, state_loss: 0.45639056714800186


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 1, Loss: 1445.9902458477736, offset_loss: 1445.6989755271968, state_loss: 0.2912678057537939


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   1 |  3700/ 3708 batches | | ms/batch 57.60 | loss 1315.62 | offset 1315.37 | state  0.25
Training: Epoch: 2, Loss: 1345.4794892410887, offset_loss: 1345.2172191829743, state_loss: 0.26227054736117844


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 2, Loss: 1269.883427440672, offset_loss: 1269.6329331935797, state_loss: 0.2504954796312447


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   2 |  3700/ 3708 batches | | ms/batch 57.30 | loss 1300.67 | offset 1300.41 | state  0.25
Training: Epoch: 3, Loss: 1252.6059913182594, offset_loss: 1252.3545406962933, state_loss: 0.2514506042678765


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 3, Loss: 1225.6350345468163, offset_loss: 1225.3840740461994, state_loss: 0.25096009940581215


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   3 |  3700/ 3708 batches | | ms/batch 57.73 | loss 1207.76 | offset 1207.51 | state  0.25
Training: Epoch: 4, Loss: 1221.876573910348, offset_loss: 1221.6242116621527, state_loss: 0.25236259104268066


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 4, Loss: 1196.5014414392915, offset_loss: 1196.2513735204711, state_loss: 0.250065427637638


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   4 |  3700/ 3708 batches | | ms/batch 58.18 | loss 1242.18 | offset 1241.93 | state  0.25
Training: Epoch: 5, Loss: 1204.9841408672971, offset_loss: 1204.733822506737, state_loss: 0.25031880347586627


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 5, Loss: 1183.6659357744948, offset_loss: 1183.417851727708, state_loss: 0.24808230413530105


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   5 |  3700/ 3708 batches | | ms/batch 58.60 | loss 1168.12 | offset 1167.87 | state  0.24
Training: Epoch: 6, Loss: 1192.219762824757, offset_loss: 1191.9722733976, state_loss: 0.24749032870947735


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 6, Loss: 1178.2000108302984, offset_loss: 1177.9522769325658, state_loss: 0.2477323152965173


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   6 |  3700/ 3708 batches | | ms/batch 60.55 | loss 1236.37 | offset 1236.12 | state  0.25
Training: Epoch: 7, Loss: 1180.73214362843, offset_loss: 1180.4866375804847, state_loss: 0.24550525552744876


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 7, Loss: 1165.639485524113, offset_loss: 1165.3950874500704, state_loss: 0.2443976633082655


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   7 |  3700/ 3708 batches | | ms/batch 57.88 | loss 1141.43 | offset 1141.18 | state  0.24
Training: Epoch: 8, Loss: 1171.2075934878187, offset_loss: 1170.9637359487458, state_loss: 0.24385760452883656


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 8, Loss: 1156.780706649436, offset_loss: 1156.5365926412712, state_loss: 0.2441131805342839


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   8 |  3700/ 3708 batches | | ms/batch 58.25 | loss 1148.94 | offset 1148.70 | state  0.24
Training: Epoch: 9, Loss: 1164.8490820661461, offset_loss: 1164.6064543163377, state_loss: 0.24262752283481076


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 9, Loss: 1156.0761223126174, offset_loss: 1155.8342610983025, state_loss: 0.24185735188928761


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch   9 |  3700/ 3708 batches | | ms/batch 58.97 | loss 1166.94 | offset 1166.70 | state  0.24
Training: Epoch: 10, Loss: 1158.1085921105444, offset_loss: 1157.8674163159944, state_loss: 0.24117566722303532


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 10, Loss: 1144.7788040046405, offset_loss: 1144.5390565341577, state_loss: 0.2397490367853552


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  10 |  3700/ 3708 batches | | ms/batch 59.23 | loss 1173.38 | offset 1173.14 | state  0.24
Training: Epoch: 11, Loss: 1153.5183515332662, offset_loss: 1153.2782744044496, state_loss: 0.24007609340300967


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 11, Loss: 1139.3343152497944, offset_loss: 1139.0954351210057, state_loss: 0.2388817337446643


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  11 |  3700/ 3708 batches | | ms/batch 57.28 | loss 1130.87 | offset 1130.63 | state  0.24
Training: Epoch: 12, Loss: 1148.9670894914975, offset_loss: 1148.7283274991098, state_loss: 0.2387616840807044


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 12, Loss: 1130.2671329813793, offset_loss: 1130.0290499809093, state_loss: 0.23808145500663527


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  12 |  3700/ 3708 batches | | ms/batch 57.06 | loss 1172.59 | offset 1172.35 | state  0.24
Training: Epoch: 13, Loss: 1144.9960328300665, offset_loss: 1144.7583068057556, state_loss: 0.23772604137000014


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 13, Loss: 1133.2657135698132, offset_loss: 1133.0285103016329, state_loss: 0.23720081038492963


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  13 |  3700/ 3708 batches | | ms/batch 57.51 | loss 1157.93 | offset 1157.69 | state  0.24
Training: Epoch: 14, Loss: 1141.7121162291098, offset_loss: 1141.474936492492, state_loss: 0.23717986705523092


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 14, Loss: 1130.6845216679394, offset_loss: 1130.4476355072252, state_loss: 0.23689169823227071


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  14 |  3700/ 3708 batches | | ms/batch 58.76 | loss 1138.89 | offset 1138.65 | state  0.24
Training: Epoch: 15, Loss: 1138.747296972182, offset_loss: 1138.510299880142, state_loss: 0.2369976050854372


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 15, Loss: 1123.0721779730086, offset_loss: 1122.8364560693726, state_loss: 0.23572466359998948


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  15 |  3700/ 3708 batches | | ms/batch 57.68 | loss 1126.05 | offset 1125.81 | state  0.24
Training: Epoch: 16, Loss: 1135.296761011743, offset_loss: 1135.06057291535, state_loss: 0.23618873447067273


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 16, Loss: 1119.6419962259163, offset_loss: 1119.406757096599, state_loss: 0.23523966620739242


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  16 |  3700/ 3708 batches | | ms/batch 57.62 | loss 1102.16 | offset 1101.93 | state  0.24
Training: Epoch: 17, Loss: 1132.847086226413, offset_loss: 1132.6114836099207, state_loss: 0.23560208945807387


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 17, Loss: 1118.0558315649964, offset_loss: 1117.8214400442023, state_loss: 0.23439358219616396


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  17 |  3700/ 3708 batches | | ms/batch 58.11 | loss 1148.38 | offset 1148.15 | state  0.23
Training: Epoch: 18, Loss: 1130.2511873708188, offset_loss: 1130.016273243399, state_loss: 0.23491403985740425


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 18, Loss: 1142.4805440113958, offset_loss: 1142.2457752658013, state_loss: 0.23476576446590566


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  18 |  3700/ 3708 batches | | ms/batch 58.36 | loss 1146.56 | offset 1146.33 | state  0.23
Training: Epoch: 19, Loss: 1126.93312986306, offset_loss: 1126.6987569041494, state_loss: 0.2343738216629499


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 19, Loss: 1115.418574053542, offset_loss: 1115.1851733214874, state_loss: 0.2334064207130805


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  19 |  3700/ 3708 batches | | ms/batch 58.02 | loss 1102.84 | offset 1102.61 | state  0.23
Training: Epoch: 20, Loss: 1125.0530293288741, offset_loss: 1124.8190666453866, state_loss: 0.23396251061298315


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 20, Loss: 1114.013727403225, offset_loss: 1113.7794730968046, state_loss: 0.2342489767343478


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  20 |  3700/ 3708 batches | | ms/batch 57.81 | loss 1128.72 | offset 1128.48 | state  0.23
Training: Epoch: 21, Loss: 1122.7658596923566, offset_loss: 1122.5321203270905, state_loss: 0.23373921801453656


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 21, Loss: 1109.5187607385162, offset_loss: 1109.2853058549695, state_loss: 0.23345347992459634


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  21 |  3700/ 3708 batches | | ms/batch 57.35 | loss 1157.38 | offset 1157.15 | state  0.23
Training: Epoch: 22, Loss: 1121.8339690174485, offset_loss: 1121.6005888674508, state_loss: 0.2333799599877839


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 22, Loss: 1111.5064339315084, offset_loss: 1111.2736600718104, state_loss: 0.23277464923553898


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  22 |  3700/ 3708 batches | | ms/batch 58.24 | loss 1164.40 | offset 1164.17 | state  0.23
Training: Epoch: 23, Loss: 1120.2083453988052, offset_loss: 1119.9753455663063, state_loss: 0.23300029555217092


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 23, Loss: 1109.6783658364661, offset_loss: 1109.4460293189027, state_loss: 0.2323349043166727


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  23 |  3700/ 3708 batches | | ms/batch 58.11 | loss 1118.69 | offset 1118.46 | state  0.23
Training: Epoch: 24, Loss: 1119.4704853703931, offset_loss: 1119.237662651778, state_loss: 0.23282221225296282


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 24, Loss: 1116.1122665978912, offset_loss: 1115.880062705592, state_loss: 0.2321959775417371


  0%|          | 0/3708 [00:00<?, ?it/s]

| epoch  24 |  3700/ 3708 batches | | ms/batch 58.09 | loss 1090.81 | offset 1090.58 | state  0.23
Training: Epoch: 25, Loss: 1118.7110567885052, offset_loss: 1118.478519674258, state_loss: 0.23253772671393722


  0%|          | 0/133 [00:00<?, ?it/s]

Validation: Epoch: 25, Loss: 1104.7901969278666, offset_loss: 1104.5585983391095, state_loss: 0.23159880006223693
