In [1]:
import time
import math
from functools import partial
from tqdm.auto import tqdm
from tqdm.utils import _term_move_up

prefix = _term_move_up() + '\r'

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# from mamba import Mamba, MambaConfig
from mamba_ssm import Mamba
# from mamba_ssm.modules.mamba_simple import Block
from mamba_ssm.models.mixer_seq_simple import create_block, _init_weights

from transformers import get_cosine_schedule_with_warmup

from data import DrawingDataset

%load_ext autoreload
%autoreload 2

2024-03-06 10:01:25.935749: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
print(torch.cuda.get_device_properties(0))

True
NVIDIA GeForce GTX 1080
_CudaDeviceProperties(name='NVIDIA GeForce GTX 1080', major=6, minor=1, total_memory=8110MB, multi_processor_count=20)


In [3]:
batch_size = 256
max_length = 100

train_dataset = DrawingDataset(data_path="./data", split="train", max_length=max_length)
val_dataset = DrawingDataset(data_path="./data", split="valid", max_length=max_length)
test_dataset = DrawingDataset(data_path="./data", split="test", max_length=max_length)

train = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
test = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

  0%|          | 1/345 [00:00<01:47,  3.21it/s]
  0%|          | 1/345 [00:00<00:04, 83.16it/s]
  0%|          | 1/345 [00:00<00:04, 70.32it/s]


In [32]:
# class mambaBlock(nn.Module):
#     def __init__(self, d_model, d_state, d_conv, expand, n_layers):
#         super(mambaBlock, self).__init__()
#         self.d_model = d_model
#         self.d_state = d_state
#         self.d_conv = d_conv
#         self.expand = expand
#         self.n_layers = n_layers
#         self.layers = []

#         self.layers = nn.ModuleList([Mamba(d_model, d_state, d_conv, expand) for _ in range(n_layers)])

#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return x

class customModel(nn.Module):
    def __init__(self, nb, no, ns, embed_dim, state_hidden):
        super(customModel, self).__init__()
        
        self.embed_dim = embed_dim
        self.state_hidden = state_hidden
        self.proj = nn.Linear(in_features=5, out_features=self.embed_dim, bias=False)
        
        self.m1 = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'm{i}') for i in range(nb)])
        self.leftm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'l{i}') for i in range(no)])
        self.rightm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'r{i}') for i in range(ns)])
        
        self.offset_out = nn.Linear(in_features=self.embed_dim, out_features=2, bias=False)
        self.state_out = nn.Linear(in_features=self.embed_dim, out_features=3, bias=False)
        
        initializer_cfg = None
        for layer in self.m1:
            layer.apply(partial(_init_weights, n_layer=nb, **(initializer_cfg if initializer_cfg is not None else {})))
        for layer in self.leftm:
            layer.apply(partial(_init_weights, n_layer=no, **(initializer_cfg if initializer_cfg is not None else {})))
        for layer in self.rightm:
            layer.apply(partial(_init_weights, n_layer=ns, **(initializer_cfg if initializer_cfg is not None else {})))
        
            

    def forward(self, x): # x is of shape (B, L, 5) (Batchsize, sequence length, dimension)
        x = self.proj(x)
        hidden_states, residuals = x, None
        for layer in self.m1:
            hidden_states, residuals = layer(hidden_states, residuals)
        
        left_hidden_states, left_residuals = hidden_states, residuals
        right_hidden_states, right_residuals = hidden_states, residuals

        for layer in self.leftm:
            left_hidden_states, left_residuals = layer(left_hidden_states, left_residuals)
        for layer in self.rightm:
            right_hidden_states, right_residuals = layer(right_hidden_states, right_residuals)
        
        offset_out = self.offset_out(left_hidden_states)
        state_out = self.state_out(right_hidden_states)
        return offset_out, state_out

    @torch.inference_mode()
    def generate(self, input_seq):
        prev = input_seq
        
        i = 0
        while i < 10:
            offset, state = self.forward(prev)
            pred = torch.cat((offset, state), dim=-1)
            next_seg = pred[:, -1:, :]
            prev = torch.cat((prev, next_seg), dim=1)
            i += 1
        
        return prev

    def save(self, path):
        torch.save(self.state_dict(), path)
    
    def load(self, path):
        self.load_state_dict(torch.load(path))

In [33]:
log_interval = 10
epochs = 5

batches = len(train)

model = customModel(nb=4, no=2, ns=2, embed_dim=32, state_hidden=128).to("cuda")

offset_crit = nn.MSELoss(reduction='none')
state_crit = nn.CrossEntropyLoss(ignore_index=2)

#optimizer = torch.optim.RAdam(model.parameters(), lr=5e-4)
warmup_ratio = 0.03

print(0.05 * math.sqrt(batch_size / (batches * epochs)))
optimizer = torch.optim.AdamW(model.parameters(), 
                              lr=1e-4, betas=(0.99, 0.95), eps=1e-4,
                              weight_decay=0.05 * math.sqrt(batch_size / (batches * epochs)))   
# scheduler = get_cosine_schedule_with_warmup(optimizer, 
#                                             num_warmup_steps=batches * warmup_ratio,
#                                             num_training_steps=batches)

writer = SummaryWriter('./logs')

def train_model(model, data_loader, optimizer, epoch):
    model.train()
    
    size = len(data_loader)
    
    # Total Losses
    total_loss = 0
    total_offset_loss = 0
    total_state_loss = 0
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    running_correct = 0
    running_total = 0
    
    running_mse = 0
    
    start_time = time.time()
    
    for i, data in enumerate(tqdm(data_loader)):
        optimizer.zero_grad()
        
        inputs, targets = data
        inputs = inputs.to("cuda")
        targets = targets.to("cuda")
        
        offsets, states = model(inputs)
        
        # Split Target
        offset_target = targets[:, :, :2]
        state_target = targets[:, :, 2:].argmax(dim=-1)
        no_pad_mask = state_target != 2
        
        # Masked MSE Loss for offset
        offset_loss = offset_crit(offsets, offset_target)
        offset_loss_mask = offset_loss * no_pad_mask.unsqueeze(-1).float()
        
        offset_loss = offset_loss_mask.sum() / no_pad_mask.sum()
        
        # Cross Entropy Loss for State
        state_loss = state_crit(states.transpose(1, 2), state_target)
        loss = offset_loss + state_loss
        
        # Calculate other metrics (accuracy)
        with torch.no_grad():
            states_softmax = torch.nn.functional.softmax(states, dim=-1)
            states_pred = states_softmax.argmax(dim=-1)
            
            no_pad_mask = state_target.flatten() != 2

            running_correct += (states_pred.flatten()[no_pad_mask] == state_target.flatten()[no_pad_mask]).sum().item()
            running_total += states_pred.flatten()[no_pad_mask].size().numel()
            
            flat_offsets_pred = offsets.reshape(-1, 2)[no_pad_mask, :]
            flat_offset_target = targets[:, :, :2].reshape(-1, 2)[no_pad_mask, :]
            running_mse += nn.functional.mse_loss(flat_offsets_pred, flat_offset_target)
        
        # Backprop
        loss.backward(retain_graph=True)
        
        # Gradient Clipping
        for name, param in model.named_parameters():
            torch.nn.utils.clip_grad_norm_(param, max_norm=1.0)
        
        # Optimizer Steps
        optimizer.step()
        
        running_loss += loss.item()
        running_offset_loss += offset_loss.item()
        running_state_loss += state_loss.item()
        
        total_loss += loss.item()
        total_offset_loss += offset_loss.item()
        total_state_loss += state_loss.item()
        
        # Print speed, losses, and accuracy every 25 batchs
        if i % log_interval == 0 and i > 0:
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = running_loss / log_interval
            cur_offset_loss = running_offset_loss / log_interval
            cur_state_loss = running_state_loss / log_interval
            cur_accuracy = running_correct / running_total
            cur_mse = running_mse / log_interval
            tqdm.write(f'{prefix}| epoch {(epoch+1):3d} | {i:5d}/{size:5d} batches '
                  f'| ms/batch {ms_per_batch:5.2f} | '
                  f'offset_loss {cur_offset_loss:5.2f} | state_loss {cur_state_loss:5.4f} | '
                  f'accuracy {cur_accuracy:5.4f} | mse {cur_mse:5.2f}')
            time.sleep(0)
            running_loss = 0
            running_offset_loss = 0
            running_state_loss = 0
            running_correct = 0
            running_total = 0
            running_mse = 0
            start_time = time.time()
    
    return total_loss / size, total_offset_loss / size, total_state_loss / size
        
def evaluate_model(model, data_loader):
    model.eval()
    size = len(data_loader)
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    # Correct
    correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for data in tqdm(data_loader):
            inputs, targets = data
            inputs = inputs.to("cuda")
            targets = targets.to("cuda")
            
            offsets, states = model(inputs)
            
            # Masked MSE Loss
            offset_target = targets[:, :, :2]
            state_target = targets[:, :, 2:].argmax(dim=-1)
            no_pad_mask = state_target != 2
            
            offset_loss = offset_crit(offsets, offset_target)
            offset_loss_mask = offset_loss * no_pad_mask.unsqueeze(-1).float()
            offset_loss = offset_loss_mask.sum() / no_pad_mask.sum()
            
            state_loss = state_crit(states.transpose(1, 2), state_target)
            loss = offset_loss + state_loss
            
            # Accuracy Calculation
            states_softmax = torch.nn.functional.softmax(states, dim=-1)
            states_pred = states_softmax.argmax(dim=-1)
            
            no_pad_mask = no_pad_mask.flatten()
            
            correct += (states_pred.flatten()[no_pad_mask] == state_target.flatten()[no_pad_mask]).sum().item()
            total_samples += states_pred.flatten()[no_pad_mask].size().numel()
            
            running_loss += loss.item()
            running_offset_loss += offset_loss.item()
            running_state_loss += state_loss.item()
    
    return running_loss / size, running_offset_loss / size, running_state_loss / size, correct / total_samples
            


0.04832976746641415


In [34]:
torch.autograd.set_detect_anomaly(True)
for epoch in range(epochs):
    train_loss, train_offset_loss, train_state_loss = train_model(model, train, optimizer, epoch)
    print(f"Training: Epoch: {epoch+1}, offset_loss: {train_offset_loss:5.4f}, state_loss: {train_state_loss:5.4f}")
    writer.add_scalar("Train/Loss/Epoch", train_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", train_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", train_state_loss, epoch)
    
    val_loss, val_offset_loss, val_state_loss, val_accuracy = evaluate_model(model, val)
    print(f"Validation: Epoch: {epoch+1}, offset_loss: {val_offset_loss:5.2f}, state_loss: {val_state_loss:5.4f}, accuracy: {val_accuracy:5.4f}")
    writer.add_scalar("Train/Loss/Epoch", val_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", val_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", val_state_loss, epoch)
    
    #scheduler.step()

  0%|          | 0/274 [00:00<?, ?it/s]

| epoch   1 |   270/  274 batches | ms/batch 99.50 | offset_loss 10756.93 | state_loss 0.9578 | accuracy 0.8107 | mse 5378.463
Training: Epoch: 1, offset_loss: 10815.6281, state_loss: 1.0415


  0%|          | 0/10 [00:00<?, ?it/s]

Validation: Epoch: 1, offset_loss: 10983.63, state_loss: 0.9482, accuracy: 0.8136


In [7]:
test_loss, test_offset_loss, test_state_loss, test_accuracy = evaluate_model(model, test)
print(f"Test: offset_loss: {test_offset_loss:5.2f}, state_loss: {test_state_loss:5.4f}, accuracy: {test_accuracy:5.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Test: offset_loss: 10632.34, state_loss: 0.9628, accuracy: 0.8254


In [8]:
model.save('./saved/model.pth')

In [9]:
loaded_model = customModel(nb=4, no=2, ns=2, embed_dim=32, state_hidden=128).to("cuda")
loaded_model.load('./saved/model.pth')

In [10]:
test_loss, test_offset_loss, test_state_loss, test_accuracy = evaluate_model(loaded_model, test)
print(f"Test: offset_loss: {test_offset_loss:5.2f}, state_loss: {test_state_loss:5.4f}, accuracy: {test_accuracy:5.4f}")

  0%|          | 0/10 [00:00<?, ?it/s]

Test: offset_loss: 10612.51, state_loss: 0.9628, accuracy: 0.8254


In [35]:
i = (train_dataset[0][0][:2]).unsqueeze(0)
print(i.shape)
print(model.generate(i.to("cuda")))
print(train_dataset[0][0][:10])

torch.Size([1, 2, 5])
tensor([[[ 1.6000e+01, -5.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.9000e+01, -1.2000e+01,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 4.6390e-02, -3.8057e-01,  2.8873e-01, -8.5750e-02, -1.7887e-01],
         [ 3.2328e-02, -3.8309e-01,  6.6125e-02,  1.2527e-02, -1.7771e-02],
         [-3.0285e-02, -7.4831e-01,  6.4276e-02, -5.0272e-03, -3.5066e-02],
         [-6.7661e-02, -9.0558e-01,  9.7885e-02, -2.4203e-02, -6.5083e-02],
         [-4.4013e-02, -7.8273e-01,  1.1584e-01, -1.1298e-02, -8.4086e-02],
         [-5.8402e-02, -7.2081e-01,  1.2373e-01, -9.4407e-03, -8.1001e-02],
         [-8.3008e-02, -7.2284e-01,  1.2082e-01, -8.5030e-03, -8.2599e-02],
         [-1.0626e-01, -7.5436e-01,  1.2234e-01, -8.5693e-03, -8.5260e-02],
         [-1.2505e-01, -7.7243e-01,  1.2460e-01, -8.4539e-03, -8.8489e-02],
         [-1.3894e-01, -7.7977e-01,  1.2763e-01, -8.0513e-03, -9.1408e-02]]],
       device='cuda:0')
tensor([[  16.,   -5.,    1.,    0.,    