In [1]:
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

from data import DrawingDataset, collate_batch

# from mamba import Mamba, MambaConfig
from mamba_ssm import Mamba
# from mamba_ssm.modules.mamba_simple import Block
from mamba_ssm.models.mixer_seq_simple import create_block

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
print(torch.cuda.get_device_properties(0))

True
NVIDIA GeForce RTX 3050 Ti Laptop GPU
_CudaDeviceProperties(name='NVIDIA GeForce RTX 3050 Ti Laptop GPU', major=8, minor=6, total_memory=4095MB, multi_processor_count=20)


In [3]:
train_dataset = DrawingDataset(data_path="./data", split="train", max_length=100)
val_dataset = DrawingDataset(data_path="./data", split="valid", max_length=100)
test_dataset = DrawingDataset(data_path="./data", split="test", max_length=100)

train = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)
val = DataLoader(dataset=val_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)
test = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, collate_fn=collate_batch)

  0%|          | 1/345 [00:00<01:38,  3.49it/s]
  0%|          | 1/345 [00:00<00:03, 99.41it/s]
  0%|          | 1/345 [00:00<00:03, 95.13it/s]


In [4]:
# class mambaBlock(nn.Module):
#     def __init__(self, d_model, d_state, d_conv, expand, n_layers):
#         super(mambaBlock, self).__init__()
#         self.d_model = d_model
#         self.d_state = d_state
#         self.d_conv = d_conv
#         self.expand = expand
#         self.n_layers = n_layers
#         self.layers = []

#         self.layers = nn.ModuleList([Mamba(d_model, d_state, d_conv, expand) for _ in range(n_layers)])

#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return x

class customModel(nn.Module):
    def __init__(self, n_layers):
        super(customModel, self).__init__()
        # self.d_model = d_model
        self.n_layers = n_layers
        # self.m1 = mambaBlock(5, d_state, d_conv, expand, n_layers)
        # self.leftm = mambaBlock(5, d_state, d_conv, expand, n_layers)
        # self.rightm = mambaBlock(5, d_state, d_conv, expand, n_layers)
        self.m1 = nn.ModuleList([create_block(5, device='cuda', layer_idx='i') for i in range(n_layers)])
        self.leftm = nn.ModuleList([create_block(5, device='cuda', layer_idx='i') for i in range(n_layers)])
        self.rightm = nn.ModuleList([create_block(5, device='cuda', layer_idx='i') for i in range(n_layers)])

    def forward(self, x): # x is of shape (B, L, 5) (Batchsize, sequence length, dimension)
        # x = self.m1(x)
        hidden_states, residuals = x, None
        for layer in self.m1:
            hidden_states, residuals = layer(hidden_states, residuals)
        
        left_hidden_states, left_residuals = hidden_states, residuals
        right_hidden_states, right_residuals = hidden_states, residuals

        for layer in self.leftm:
            left_hidden_states, left_residuals = layer(left_hidden_states, left_residuals)
        for layer in self.rightm:
            right_hidden_states, right_residuals = layer(right_hidden_states, right_residuals)
        return left_hidden_states[:, :, :2], right_hidden_states[:, :, 2:]

In [8]:
model = customModel(n_layers=4).to("cuda")

offset_crit = nn.MSELoss()
state_crit = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print_every = 500
epochs = 5

for epoch in range(epochs):
    loss = None
    offset_loss = None
    state_loss = None
    for data in tqdm(train):
    # for i, data in enumerate(train):
        inputs, targets = data
        inputs = inputs.to("cuda")
        targets = targets.to("cuda")
        optimizer.zero_grad()
        offsets, states = model(inputs)
        offset_loss = offset_crit(offsets, targets[:, :, :2])
        state_loss = state_crit(states, targets[:, :, 2:])
        loss = offset_loss + state_loss
        loss.backward()
        optimizer.step()
        # if i % print_every == 0:
        #     print(f"Epoch: {epoch+1}, Batch: {i}/{len(train)}, Loss: {loss}, offset_loss: {offset_loss}, state_loss: {state_loss}")
    print(f"Epoch: {epoch+1}, Loss: {loss}, offset_loss: {offset_loss}, state_loss: {state_loss}")

  0%|          | 0/1094 [00:00<?, ?it/s]

100%|██████████| 1094/1094 [00:26<00:00, 40.80it/s]


Epoch: 1, Loss: 1073.641845703125, offset_loss: 940.0738525390625, state_loss: 133.56805419921875


100%|██████████| 1094/1094 [00:26<00:00, 40.98it/s]


Epoch: 2, Loss: 987.791015625, offset_loss: 856.3636474609375, state_loss: 131.42738342285156


100%|██████████| 1094/1094 [00:29<00:00, 36.60it/s]


Epoch: 3, Loss: 968.4552612304688, offset_loss: 836.7730712890625, state_loss: 131.6822052001953


100%|██████████| 1094/1094 [00:30<00:00, 35.81it/s]


Epoch: 4, Loss: 1494.5880126953125, offset_loss: 1363.9757080078125, state_loss: 130.61233520507812


100%|██████████| 1094/1094 [00:32<00:00, 33.16it/s]

Epoch: 5, Loss: 1000.082763671875, offset_loss: 869.1807250976562, state_loss: 130.9020538330078





True