In [1]:
import time
import math
from functools import partial
from tqdm.auto import tqdm
from tqdm.utils import _term_move_up

prefix = _term_move_up() + '\r'

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# from mamba import Mamba, MambaConfig
from mamba_ssm import Mamba
# from mamba_ssm.modules.mamba_simple import Block
from mamba_ssm.models.mixer_seq_simple import create_block, _init_weights

from transformers import get_cosine_schedule_with_warmup

from data import DrawingDataset
from customModel import customModel

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
print(torch.cuda.get_device_properties(0))

True
NVIDIA GeForce RTX 3050 Ti Laptop GPU
_CudaDeviceProperties(name='NVIDIA GeForce RTX 3050 Ti Laptop GPU', major=8, minor=6, total_memory=4095MB, multi_processor_count=20)


In [3]:
batch_size = 512
max_length = 100

train_dataset = DrawingDataset(data_path="./data", split="train", max_length=max_length)
val_dataset = DrawingDataset(data_path="./data", split="valid", max_length=max_length)
test_dataset = DrawingDataset(data_path="./data", split="test", max_length=max_length)

train = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
test = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

['./data/The Eiffel Tower.npz', './data/The Great Wall of China.npz', './data/The Mona Lisa.npz', './data/aircraft carrier.npz', './data/airplane.npz', './data/alarm clock.npz', './data/ambulance.npz', './data/angel.npz', './data/animal migration.npz', './data/ant.npz', './data/anvil.npz', './data/apple.npz', './data/arm.npz', './data/asparagus.npz', './data/axe.npz', './data/backpack.npz', './data/banana.npz', './data/bandage.npz', './data/barn.npz', './data/baseball bat.npz', './data/baseball.npz', './data/basket.npz', './data/basketball.npz', './data/bat.npz', './data/bathtub.npz', './data/beach.npz', './data/bear.npz', './data/beard.npz', './data/bed.npz', './data/bee.npz', './data/belt.npz', './data/bench.npz', './data/bicycle.npz', './data/binoculars.npz', './data/bird.npz', './data/birthday cake.npz', './data/blackberry.npz', './data/blueberry.npz', './data/book.npz', './data/boomerang.npz', './data/bottlecap.npz', './data/bowtie.npz', './data/bracelet.npz', './data/brain.npz', 

  0%|          | 0/345 [00:00<?, ?it/s]

  0%|          | 1/345 [00:00<02:12,  2.60it/s]


['./data/The Eiffel Tower.npz', './data/The Great Wall of China.npz', './data/The Mona Lisa.npz', './data/aircraft carrier.npz', './data/airplane.npz', './data/alarm clock.npz', './data/ambulance.npz', './data/angel.npz', './data/animal migration.npz', './data/ant.npz', './data/anvil.npz', './data/apple.npz', './data/arm.npz', './data/asparagus.npz', './data/axe.npz', './data/backpack.npz', './data/banana.npz', './data/bandage.npz', './data/barn.npz', './data/baseball bat.npz', './data/baseball.npz', './data/basket.npz', './data/basketball.npz', './data/bat.npz', './data/bathtub.npz', './data/beach.npz', './data/bear.npz', './data/beard.npz', './data/bed.npz', './data/bee.npz', './data/belt.npz', './data/bench.npz', './data/bicycle.npz', './data/binoculars.npz', './data/bird.npz', './data/birthday cake.npz', './data/blackberry.npz', './data/blueberry.npz', './data/book.npz', './data/boomerang.npz', './data/bottlecap.npz', './data/bowtie.npz', './data/bracelet.npz', './data/brain.npz', 

  0%|          | 1/345 [00:00<00:05, 64.73it/s]


['./data/The Eiffel Tower.npz', './data/The Great Wall of China.npz', './data/The Mona Lisa.npz', './data/aircraft carrier.npz', './data/airplane.npz', './data/alarm clock.npz', './data/ambulance.npz', './data/angel.npz', './data/animal migration.npz', './data/ant.npz', './data/anvil.npz', './data/apple.npz', './data/arm.npz', './data/asparagus.npz', './data/axe.npz', './data/backpack.npz', './data/banana.npz', './data/bandage.npz', './data/barn.npz', './data/baseball bat.npz', './data/baseball.npz', './data/basket.npz', './data/basketball.npz', './data/bat.npz', './data/bathtub.npz', './data/beach.npz', './data/bear.npz', './data/beard.npz', './data/bed.npz', './data/bee.npz', './data/belt.npz', './data/bench.npz', './data/bicycle.npz', './data/binoculars.npz', './data/bird.npz', './data/birthday cake.npz', './data/blackberry.npz', './data/blueberry.npz', './data/book.npz', './data/boomerang.npz', './data/bottlecap.npz', './data/bowtie.npz', './data/bracelet.npz', './data/brain.npz', 

  0%|          | 1/345 [00:00<00:05, 66.78it/s]


In [4]:
def get_drawing_size(drawing):
    minx, miny = 0, 0
    maxx, maxy = 0, 0
    for stroke in drawing:
        minx = min(minx, stroke[0])
        miny = min(miny, stroke[1])
        maxx = max(maxx, stroke[0])
        maxy = max(maxy, stroke[1])
    return int(maxx - minx), int(maxy - miny)
for i in range(25):
    print(get_drawing_size(train_dataset[1000+i][0]))
# print(len(train_dataset[1]))

(373, 166)
(324, 220)
(173, 67)
(299, 102)
(443, 154)
(183, 79)
(341, 187)
(314, 220)
(310, 172)
(387, 231)
(190, 153)
(314, 186)
(259, 135)
(249, 122)
(238, 154)
(351, 134)
(478, 151)
(425, 197)
(179, 122)
(204, 240)
(265, 151)
(286, 127)
(139, 156)
(248, 141)
(109, 265)


In [5]:
# class mambaBlock(nn.Module):
#     def __init__(self, d_model, d_state, d_conv, expand, n_layers):
#         super(mambaBlock, self).__init__()
#         self.d_model = d_model
#         self.d_state = d_state
#         self.d_conv = d_conv
#         self.expand = expand
#         self.n_layers = n_layers
#         self.layers = []

#         self.layers = nn.ModuleList([Mamba(d_model, d_state, d_conv, expand) for _ in range(n_layers)])

#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return x

# class customModel(nn.Module):
#     def __init__(self, nb, no, ns, embed_dim):
#         super(customModel, self).__init__()
        
#         self.embed_dim = embed_dim
#         self.proj = nn.Linear(in_features=5, out_features=self.embed_dim, bias=False)
        
#         self.m1 = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'm{i}') for i in range(nb)])
#         self.leftm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'l{i}') for i in range(no)])
#         self.rightm = nn.ModuleList([create_block(self.embed_dim , device='cuda', layer_idx=f'r{i}') for i in range(ns)])
        
#         self.offset_hidden = nn.Linear(in_features=self.embed_dim, out_features=16)
#         self.relu = nn.ReLU()
#         self.offset_out = nn.Linear(in_features=16, out_features=2)
#         self.state_out = nn.Linear(in_features=self.embed_dim, out_features=3, bias=False)
        
#         initializer_cfg = None
#         for layer in self.m1:
#             layer.apply(partial(_init_weights, n_layer=nb, **(initializer_cfg if initializer_cfg is not None else {})))
#         for layer in self.leftm:
#             layer.apply(partial(_init_weights, n_layer=no, **(initializer_cfg if initializer_cfg is not None else {})))
#         for layer in self.rightm:
#             layer.apply(partial(_init_weights, n_layer=ns, **(initializer_cfg if initializer_cfg is not None else {})))
        
            

#     def forward(self, x): # x is of shape (B, L, 5) (Batchsize, sequence length, dimension)
#         x = self.proj(x)
#         hidden_states, residuals = x, None
#         for layer in self.m1:
#             hidden_states, residuals = layer(hidden_states, residuals)
        
#         left_hidden_states, left_residuals = hidden_states, residuals
#         right_hidden_states, right_residuals = hidden_states, residuals

#         for layer in self.leftm:
#             left_hidden_states, left_residuals = layer(left_hidden_states, left_residuals)
#         for layer in self.rightm:
#             right_hidden_states, right_residuals = layer(right_hidden_states, right_residuals)
        
#         offset_out = self.relu(self.offset_hidden(left_hidden_states))
#         offset_out = self.offset_out(offset_out)
#         state_out = self.state_out(right_hidden_states)
#         return offset_out, state_out

#     @torch.inference_mode()
#     def generate(model, input_seq):
#         prev = input_seq
        
#         i = 0
#         while i < 20:
#             offset, state = model.forward(prev)
#             indices = torch.argmax(state, dim=-1)
#             one_hot = torch.nn.functional.one_hot(indices, 3)
#             pred = torch.cat((offset, one_hot), dim=-1)
#             next_seg = pred[:, -1:, :]
#             prev = torch.cat((prev, next_seg), dim=1)
#             i += 1
        
#         return prev

#     def save(self, path):
#         torch.save(self.state_dict(), path)
    
#     def load(self, path):
#         self.load_state_dict(torch.load(path))

In [11]:
log_interval = 10
epochs = 50

batches = len(train)

model = customModel(nb=3, no=20, ns=2, embed_dim=32).to("cuda")

offset_crit = nn.MSELoss(reduction='none')
state_crit = nn.CrossEntropyLoss()

#optimizer = torch.optim.RAdam(model.parameters(), lr=5e-4)
warmup_ratio = 0.03

print(0.05 * math.sqrt(batch_size / (batches * epochs)))
optimizer = torch.optim.AdamW(model.parameters(), 
                              lr=1e-3, betas=(0.99, 0.95), eps=1e-4,
                              weight_decay=0.05 * math.sqrt(batch_size / (batches * epochs)))   
# scheduler = get_cosine_schedule_with_warmup(optimizer, 
#                                             num_warmup_steps=batches * warmup_ratio,
#                                             num_training_steps=batches)

writer = SummaryWriter('./logs')

def train_model(model, data_loader, optimizer, epoch):
    model.train()
    
    size = len(data_loader)
    
    # Total Losses
    total_loss = 0
    total_offset_loss = 0
    total_state_loss = 0
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    running_correct = 0
    running_total = 0
    
    running_mse = 0
    
    start_time = time.time()
    
    for i, data in enumerate(tqdm(data_loader)):
        optimizer.zero_grad()
        
        inputs, targets = data
        inputs = inputs.to("cuda")
        targets = targets.to("cuda")
        
        offsets, states = model(inputs)
        
        # Split Target
        offset_target = targets[:, :, :2]
        state_target = targets[:, :, 2:].argmax(dim=-1)
        no_pad_mask = state_target != 2
        
        # Masked MSE Loss for offset
        offset_loss = offset_crit(offsets, offset_target)
        offset_loss_mask = offset_loss * no_pad_mask.unsqueeze(-1).float()
        
        offset_loss = offset_loss_mask.sum() / no_pad_mask.sum()
        
        # Cross Entropy Loss for State
        state_loss = state_crit(states.transpose(1, 2), state_target)
        loss = offset_loss + state_loss
        
        # Calculate other metrics (accuracy)
        with torch.no_grad():
            states_softmax = torch.nn.functional.softmax(states, dim=-1)
            states_pred = states_softmax.argmax(dim=-1)
            
            no_pad_mask = state_target.flatten() != 2

            running_correct += (states_pred.flatten()[no_pad_mask] == state_target.flatten()[no_pad_mask]).sum().item()
            running_total += states_pred.flatten()[no_pad_mask].size().numel()
            
            flat_offsets_pred = offsets.reshape(-1, 2)[no_pad_mask, :]
            flat_offset_target = targets[:, :, :2].reshape(-1, 2)[no_pad_mask, :]
            running_mse += nn.functional.mse_loss(flat_offsets_pred, flat_offset_target)
        
        # Backprop
        loss.backward(retain_graph=True)
        
        # Gradient Clipping
        for name, param in model.named_parameters():
            torch.nn.utils.clip_grad_norm_(param, max_norm=1.5)
        
        # Optimizer Steps
        optimizer.step()
        
        running_loss += loss.item()
        running_offset_loss += offset_loss.item()
        running_state_loss += state_loss.item()
        
        total_loss += loss.item()
        total_offset_loss += offset_loss.item()
        total_state_loss += state_loss.item()
        
        # Print speed, losses, and accuracy every 25 batchs
        if i % log_interval == 0 and i > 0:
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = running_loss / log_interval
            cur_offset_loss = running_offset_loss / log_interval
            cur_state_loss = running_state_loss / log_interval
            cur_accuracy = running_correct / running_total
            cur_mse = running_mse / log_interval
            tqdm.write(f'{prefix}| epoch {(epoch+1):3d} | {i:5d}/{size:5d} batches '
                  f'| ms/batch {ms_per_batch:5.2f} | '
                  f'offset_loss {cur_offset_loss:5.2f} | state_loss {cur_state_loss:5.4f} | '
                  f'accuracy {cur_accuracy:5.4f} | mse {cur_mse:5.2f}')
            time.sleep(0)
            running_loss = 0
            running_offset_loss = 0
            running_state_loss = 0
            running_correct = 0
            running_total = 0
            running_mse = 0
            start_time = time.time()
    
    return total_loss / size, total_offset_loss / size, total_state_loss / size
        
def evaluate_model(model, data_loader):
    model.eval()
    size = len(data_loader)
    
    # Running Losses
    running_loss = 0
    running_offset_loss = 0
    running_state_loss = 0
    
    # Correct
    correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for data in tqdm(data_loader):
            inputs, targets = data
            inputs = inputs.to("cuda")
            targets = targets.to("cuda")
            
            offsets, states = model(inputs)
            
            # Masked MSE Loss
            offset_target = targets[:, :, :2]
            state_target = targets[:, :, 2:].argmax(dim=-1)
            no_pad_mask = state_target != 2
            
            offset_loss = offset_crit(offsets, offset_target)
            offset_loss_mask = offset_loss * no_pad_mask.unsqueeze(-1).float()
            offset_loss = offset_loss_mask.sum() / no_pad_mask.sum()
            
            state_loss = state_crit(states.transpose(1, 2), state_target)
            loss = offset_loss + state_loss
            
            # Accuracy Calculation
            states_softmax = torch.nn.functional.softmax(states, dim=-1)
            states_pred = states_softmax.argmax(dim=-1)
            
            no_pad_mask = no_pad_mask.flatten()
            
            correct += (states_pred.flatten()[no_pad_mask] == state_target.flatten()[no_pad_mask]).sum().item()
            total_samples += states_pred.flatten()[no_pad_mask].size().numel()
            
            running_loss += loss.item()
            running_offset_loss += offset_loss.item()
            running_state_loss += state_loss.item()
    
    return running_loss / size, running_offset_loss / size, running_state_loss / size, correct / total_samples
            


0.014426714154678916


In [7]:
torch.autograd.set_detect_anomaly(True)
for epoch in range(epochs):
    train_loss, train_offset_loss, train_state_loss = train_model(model, train, optimizer, epoch)
    print(f"Training: Epoch: {epoch+1}, offset_loss: {train_offset_loss:5.4f}, state_loss: {train_state_loss:5.4f}")
    writer.add_scalar("Train/Loss/Epoch", train_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", train_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", train_state_loss, epoch)
    
    val_loss, val_offset_loss, val_state_loss, val_accuracy = evaluate_model(model, val)
    print(f"Validation: Epoch: {epoch+1}, offset_loss: {val_offset_loss:5.2f}, state_loss: {val_state_loss:5.4f}, accuracy: {val_accuracy:5.4f}")
    writer.add_scalar("Train/Loss/Epoch", val_loss, epoch)
    writer.add_scalar("Train/Offset_Loss/Epoch", val_offset_loss, epoch)
    writer.add_scalar("Train/State_Loss/Epoch", val_state_loss, epoch)
    
    #scheduler.step()

  0%|          | 0/123 [00:00<?, ?it/s]

  9%|▉         | 11/123 [00:12<02:07,  1.14s/it]

| epoch   1 |    10/  123 batches | ms/batch 1283.82 | offset_loss 3858.65 | state_loss 1.1774 | accuracy 0.5716 | mse 1929.33


 17%|█▋        | 21/123 [00:26<02:20,  1.38s/it]

| epoch   1 |    20/  123 batches | ms/batch 1321.61 | offset_loss 3473.73 | state_loss 0.9734 | accuracy 0.7758 | mse 1736.87


 25%|██▌       | 31/123 [00:40<02:07,  1.39s/it]

| epoch   1 |    30/  123 batches | ms/batch 1444.73 | offset_loss 3520.33 | state_loss 0.8298 | accuracy 0.8286 | mse 1760.17


 33%|███▎      | 41/123 [00:53<01:48,  1.32s/it]

| epoch   1 |    40/  123 batches | ms/batch 1348.39 | offset_loss 3528.92 | state_loss 0.6731 | accuracy 0.8319 | mse 1764.46


 41%|████▏     | 51/123 [01:22<03:27,  2.89s/it]

| epoch   1 |    50/  123 batches | ms/batch 2880.52 | offset_loss 3511.27 | state_loss 0.5444 | accuracy 0.8361 | mse 1755.63


 50%|████▉     | 61/123 [01:40<01:52,  1.81s/it]

| epoch   1 |    60/  123 batches | ms/batch 1798.63 | offset_loss 3474.60 | state_loss 0.4473 | accuracy 0.8401 | mse 1737.30


 58%|█████▊    | 71/123 [02:16<04:28,  5.16s/it]

| epoch   1 |    70/  123 batches | ms/batch 3585.35 | offset_loss 3452.94 | state_loss 0.3980 | accuracy 0.8439 | mse 1726.47


 66%|██████▌   | 81/123 [02:53<01:22,  1.96s/it]

| epoch   1 |    80/  123 batches | ms/batch 3635.43 | offset_loss 3369.52 | state_loss 0.4166 | accuracy 0.8479 | mse 1684.76


 74%|███████▍  | 91/123 [03:04<00:37,  1.16s/it]

| epoch   1 |    90/  123 batches | ms/batch 1126.42 | offset_loss 3442.34 | state_loss 0.4351 | accuracy 0.8508 | mse 1721.17


 82%|████████▏ | 101/123 [03:29<00:54,  2.48s/it]

| epoch   1 |   100/  123 batches | ms/batch 2539.32 | offset_loss 3527.40 | state_loss 0.4131 | accuracy 0.8513 | mse 1763.70


 90%|█████████ | 111/123 [03:40<00:15,  1.26s/it]

| epoch   1 |   110/  123 batches | ms/batch 1130.04 | offset_loss 3463.26 | state_loss 0.3746 | accuracy 0.8512 | mse 1731.63


 98%|█████████▊| 121/123 [03:54<00:03,  1.52s/it]

| epoch   1 |   120/  123 batches | ms/batch 1389.21 | offset_loss 3461.48 | state_loss 0.3386 | accuracy 0.8524 | mse 1730.74


100%|██████████| 123/123 [03:58<00:00,  1.94s/it]


Training: Epoch: 1, offset_loss: 3475.6593, state_loss: 0.5761


100%|██████████| 5/5 [00:00<00:00,  5.77it/s]


Validation: Epoch: 1, offset_loss: 3399.97, state_loss: 0.3282, accuracy: 0.8523


  1%|          | 1/123 [00:08<18:05,  8.90s/it]


KeyboardInterrupt: 

In [8]:
test_loss, test_offset_loss, test_state_loss, test_accuracy = evaluate_model(model, test)
print(f"Test: offset_loss: {test_offset_loss:5.2f}, state_loss: {test_state_loss:5.4f}, accuracy: {test_accuracy:5.4f}")

100%|██████████| 5/5 [00:10<00:00,  2.20s/it]

Test: offset_loss: 3461.92, state_loss: 0.3268, accuracy: 0.8522





In [9]:
model.save('./saved/model1.pth')

In [None]:
loaded_model = customModel(nb=3, no=20, ns=2, embed_dim=32).to("cuda")
loaded_model.load('./saved/model.pth')

In [None]:
test_loss, test_offset_loss, test_state_loss, test_accuracy = evaluate_model(loaded_model, test)
print(f"Test: offset_loss: {test_offset_loss:5.2f}, state_loss: {test_state_loss:5.4f}, accuracy: {test_accuracy:5.4f}")

In [None]:
from utils import draw_strokes

@torch.inference_mode()
def generate(model, input_seq):
    prev = input_seq
    
    i = 0
    while i < 20:
        offset, state = model.forward(prev)
        indices = torch.argmax(state, dim=-1)
        one_hot = torch.nn.functional.one_hot(indices, 3)
        pred = torch.cat((offset, one_hot), dim=-1)
        next_seg = pred[:, -1:, :]
        prev = torch.cat((prev, next_seg), dim=1)
        i += 1
    
    return prev

strokes = train_dataset[10][0]
i = (strokes[:20]).unsqueeze(0)

output_pred = generate(model, i.to("cuda")).cpu().numpy()[0]
output_actual = strokes[:30].numpy()
draw_strokes(i[0], svg_filename='./input.svg')
draw_strokes(output_pred, svg_filename='./sample.svg')
draw_strokes(output_actual, svg_filename='./actual.svg')