In [1]:
from mamba import BitShiftMamba
from dataset import BitShiftDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
from bidirectional_mamba import BiMambaBlock, BiMambaEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
dataset = BitShiftDataset(bit_length=10, num_samples=10000)
test_set = BitShiftDataset(bit_length=10, num_samples=1000)


In [4]:
train_dataloader = DataLoader(dataset, batch_size = 32)
test_dataloader = DataLoader(test_set, batch_size = 32)

In [5]:
model = BiMambaEncoder(d_model=8,
        num_layers=2,
        d_state=4,
        d_conv=4,
        expand=2,
        dropout=0.1,
        share_ffn=False,
        share_norm=False,).to(device)
model

BiMambaEncoder(
  (token_emb): Embedding(2, 8)
  (layers): ModuleList(
    (0-1): 2 x BiMambaBlock(
      (pre_ln_f): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (mamba_f): Mamba(
        (in_proj): Linear(in_features=8, out_features=32, bias=False)
        (conv1d): Conv1d(16, 16, kernel_size=(4,), stride=(1,), padding=(3,), groups=16)
        (act): SiLU()
        (x_proj): Linear(in_features=16, out_features=9, bias=False)
        (dt_proj): Linear(in_features=1, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=8, bias=False)
      )
      (post_ln_f): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (ffn_f): Sequential(
        (0): Linear(in_features=8, out_features=32, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=32, out_features=8, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
      (pre_ln_r): LayerNorm((8,), eps=1e-05, element

In [9]:
seq, shift = next(iter(train_dataloader))
input = seq[0]
output=shift[0]



predicted = model(input.unsqueeze(0).to(device))

print(f"Input: {input}\nOutput: {output}\n")
print(f"Predicted (logits): {predicted}")

print(f"Predicted (sigmoid): {torch.sigmoid(predicted)}")
pred = (torch.sigmoid(predicted) > 0.5).long().squeeze(0)

print(f"Predicted bits:{''.join(map(str, pred.cpu().tolist()))}")

Input: tensor([1, 1, 0, 0, 1, 1, 1, 0, 0, 1])
Output: tensor([1, 0, 0, 1, 1, 1, 0, 0, 1, 1])

Predicted (logits): tensor([[-0.1679,  0.7027, -2.0501, -2.0176,  0.6059,  0.0115,  0.5256, -2.1343,
         -2.1776,  0.8366]], device='cuda:0', grad_fn=<SqueezeBackward1>)
Predicted (sigmoid): tensor([[0.4581, 0.6688, 0.1140, 0.1174, 0.6470, 0.5029, 0.6285, 0.1058, 0.1018,
         0.6977]], device='cuda:0', grad_fn=<SigmoidBackward0>)
Predicted bits:0100111001


In [10]:
loss_fn = torch.nn.BCEWithLogitsLoss()

In [11]:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [12]:
def train_one_epoch(epoch_index):

    running_loss = 0
    last_loss = 0


    for i, data in enumerate(train_dataloader):
        seqs, shifted = data
        optimizer.zero_grad()
        outputs = model(seqs.to(device))

    #    B, L, C = outputs.shape
    #    output_logits = outputs.view(B*L, C).to(device)
    #    target_flattened = shifted.view(B*L).to(device).long()


    #    loss = loss_fn(output_logits, target_flattened)
        
        loss = loss_fn(outputs.to(device), shifted.float().to(device))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i%100 == 99:
            last_loss = running_loss/100
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.
    return last_loss


In [13]:
def train( epochs=50):

    best_vloss = 1_000_000.

    for epoch in range(epochs):
     print('EPOCH {}:'.format(epoch + 1))
 
   
     model.train(True)
     avg_loss = train_one_epoch(epoch)


     running_vloss = 0.0
    
     model.eval()

   
     with torch.no_grad():
        for i, vdata in enumerate(test_dataloader):
            vinputs, vlabels = vdata
            voutputs = model(vinputs.to(device))
          #  B, L, C = voutputs.shape
          #  vloss = loss_fn(voutputs.view(B*L, C).to(device), vlabels.view(B*L).to(device))
            
            vloss = loss_fn(voutputs.to(device), vlabels.float().to(device))
            running_vloss += vloss

     avg_vloss = running_vloss / (i + 1)
     print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    
    

    
     if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = './checkpoints/model_epoch_{}'.format(epoch)
        torch.save(model.state_dict(), model_path)

     
    print("Training completed. Model available to use")





In [14]:
train(epochs=30)

EPOCH 1:
  batch 100 loss: 0.1786470664292574
  batch 200 loss: 0.07511755675077439
  batch 300 loss: 0.07158786043524742
LOSS train 0.07158786043524742 valid 0.07036567479372025
EPOCH 2:
  batch 100 loss: 0.07059268198907376
  batch 200 loss: 0.07122941270470619
  batch 300 loss: 0.0702647776901722
LOSS train 0.0702647776901722 valid 0.07171255350112915
EPOCH 3:
  batch 100 loss: 0.07047304898500442
  batch 200 loss: 0.07010565727949142
  batch 300 loss: 0.06995975069701671
LOSS train 0.06995975069701671 valid 0.06964144110679626
EPOCH 4:
  batch 100 loss: 0.07018087677657604
  batch 200 loss: 0.06986536145210266
  batch 300 loss: 0.07014162123203277
LOSS train 0.07014162123203277 valid 0.06949552148580551
EPOCH 5:
  batch 100 loss: 0.0697724113613367
  batch 200 loss: 0.06975810460746289
  batch 300 loss: 0.06945275023579597
LOSS train 0.06945275023579597 valid 0.06989763677120209
EPOCH 6:
  batch 100 loss: 0.0701641021296382
  batch 200 loss: 0.06965983435511588
  batch 300 loss: 0.

In [22]:
seq, shift = next(iter(train_dataloader))
input = seq[0]
output=shift[0]



predicted = model(input.unsqueeze(0).to(device))

print(f"Input: {input}\nOutput: {output}\n")
print(f"Predicted (logits): {predicted}")

print(f"Predicted (sigmoid): {torch.sigmoid(predicted)}")

pred = (torch.sigmoid(predicted) > 0.5).long().squeeze(0)

print(f"Predicted bits:{''.join(map(str, pred.cpu().tolist()))}")

Input: tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 1])
Output: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0])

Predicted (logits): tensor([[-12.3343,  12.1840, -12.3455, -12.3638, -12.3470, -12.3511, -12.3607,
         -12.3202,  12.1719,  -0.0603]], device='cuda:0',
       grad_fn=<SqueezeBackward1>)
Predicted (sigmoid): tensor([[4.3982e-06, 9.9999e-01, 4.3493e-06, 4.2704e-06, 4.3428e-06, 4.3248e-06,
         4.2838e-06, 4.4606e-06, 9.9999e-01, 4.8492e-01]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
Predicted bits:0100000010
