In [1]:
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm

In [2]:
with open("brackets.txt", "r") as f:
    brackets = f.readlines()
idxs = list(map(lambda line: list(map(lambda val: int(val), line.split())), brackets))
idxs = torch.tensor(idxs, dtype=torch.long)
idxs

tensor([[8, 2, 3,  ..., 7, 6, 9],
        [8, 0, 3,  ..., 7, 4, 9],
        [8, 0, 0,  ..., 4, 4, 9],
        ...,
        [8, 3, 0,  ..., 4, 7, 9],
        [8, 3, 3,  ..., 7, 7, 9],
        [8, 3, 0,  ..., 4, 7, 9]])

In [4]:
train_val_split = 0.8
train_size = int(train_val_split * len(idxs))

train_idxs = idxs[:train_size]
val_idxs = idxs[train_size:]

train_dataset = torch.utils.data.TensorDataset(train_idxs[:, :-1], train_idxs[:, 1:])
val_dataset = torch.utils.data.TensorDataset(val_idxs[:, :-1], val_idxs[:, 1:])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

In [5]:
idx_to_char = {
    0: "{",
    1: "(",
    2: "[",
    3: "<",
    4: "}",
    5: ")",
    6: "]",
    7: ">",
    8: "SOS ",
    9: " EOS",
}
def decode_brackets(brackets):
    brackets = brackets.tolist()
    return "".join([idx_to_char[idx] for idx in brackets])

decode_brackets(idxs[0])

'SOS [<{[[<{{[([[[[<>]]]])]}}>]]}>] EOS'

In [6]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs)
        loss = model.loss(outputs, targets)
        total_loss += loss.item()
    return total_loss / len(data_loader)

In [20]:
class Block(nn.Module):
    def __init__(self):
        super(Block, self).__init__()
        self.layernorm = nn.LayerNorm(64)
        self.layer1 = nn.Linear(64, 64)


class Model(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.layers = nn.Sequential(
            Block(),
            Block(),
            Block(),
        )
        self.output = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.layers(x)
        x = self.output(x)

        return x
    
    def loss(self, y_pred, y):
        return nn.functional.cross_entropy(y_pred.permute(0, 2, 1), y)

model = Model(10, 64).to(device)
summary(model, (1, 32), dtypes=[torch.long], device=device)

Layer (type:depth-idx)                   Output Shape              Param #
Model                                    [1, 32, 10]               --
├─Embedding: 1-1                         [1, 32, 64]               640
├─Sequential: 1-2                        [1, 64, 32]               --
│    └─GroupNorm: 2-1                    [1, 64, 32]               128
│    └─Conv1d: 2-2                       [1, 64, 32]               12,352
│    └─GELU: 2-3                         [1, 64, 32]               --
│    └─GroupNorm: 2-4                    [1, 64, 32]               128
│    └─Conv1d: 2-5                       [1, 64, 32]               12,352
│    └─GELU: 2-6                         [1, 64, 32]               --
│    └─GroupNorm: 2-7                    [1, 64, 32]               128
│    └─Conv1d: 2-8                       [1, 64, 32]               12,352
│    └─GELU: 2-9                         [1, 64, 32]               --
│    └─GroupNorm: 2-10                   [1, 64, 32]               12

In [21]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [22]:
if "val_loss" in locals():
    pass
else:
    val_loss = float("inf")

for epoch in range(100):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for inputs, targets in pbar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        outputs = model(inputs)
        loss = model.loss(outputs, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    val_loss = evaluate(model, val_loader)
    generated = generate(model, [8])
    print(decode_brackets(generated[0]))

Epoch 0: 100%|██████████| 32/32 [00:00<00:00, 47.48it/s, loss=1.72, val_loss=6.22]


SOS SOS 


Epoch 1: 100%|██████████| 32/32 [00:00<00:00, 57.58it/s, loss=1.87, val_loss=1.7]


SOS (


Epoch 2: 100%|██████████| 32/32 [00:00<00:00, 69.28it/s, loss=1.54, val_loss=1.81]


SOS {


Epoch 3: 100%|██████████| 32/32 [00:00<00:00, 59.86it/s, loss=1.42, val_loss=1.52]


SOS  EOS


Epoch 4: 100%|██████████| 32/32 [00:00<00:00, 70.24it/s, loss=1.49, val_loss=1.41]


SOS SOS 


Epoch 5: 100%|██████████| 32/32 [00:00<00:00, 70.27it/s, loss=1.48, val_loss=1.46]


SOS <


Epoch 6: 100%|██████████| 32/32 [00:00<00:00, 58.59it/s, loss=1.49, val_loss=1.44]


SOS <


Epoch 7: 100%|██████████| 32/32 [00:00<00:00, 70.61it/s, loss=1.86, val_loss=1.53]


SOS (


Epoch 8: 100%|██████████| 32/32 [00:00<00:00, 70.96it/s, loss=1.74, val_loss=1.88]


SOS <


Epoch 9: 100%|██████████| 32/32 [00:00<00:00, 62.93it/s, loss=1.6, val_loss=1.41] 


SOS <


Epoch 10: 100%|██████████| 32/32 [00:00<00:00, 56.38it/s, loss=1.9, val_loss=2.12] 


SOS <


Epoch 11: 100%|██████████| 32/32 [00:00<00:00, 69.84it/s, loss=1.5, val_loss=1.86] 


SOS (


Epoch 12: 100%|██████████| 32/32 [00:00<00:00, 68.52it/s, loss=1.49, val_loss=1.49]


SOS <


Epoch 13: 100%|██████████| 32/32 [00:00<00:00, 55.80it/s, loss=1.53, val_loss=1.49]


SOS <


Epoch 14: 100%|██████████| 32/32 [00:00<00:00, 63.03it/s, loss=2.29, val_loss=1.5]


SOS (


Epoch 15: 100%|██████████| 32/32 [00:00<00:00, 68.53it/s, loss=1.81, val_loss=1.64]


SOS SOS 


Epoch 16: 100%|██████████| 32/32 [00:00<00:00, 60.02it/s, loss=1.61, val_loss=1.76]


SOS SOS 


Epoch 17: 100%|██████████| 32/32 [00:00<00:00, 69.49it/s, loss=1.86, val_loss=1.87]


SOS SOS 


Epoch 18: 100%|██████████| 32/32 [00:00<00:00, 70.73it/s, loss=1.59, val_loss=1.76]


SOS SOS 


Epoch 19: 100%|██████████| 32/32 [00:00<00:00, 68.68it/s, loss=1.57, val_loss=1.62]


SOS SOS 


Epoch 20:  22%|██▏       | 7/32 [00:00<00:00, 43.64it/s, loss=1.77, val_loss=1.61]


KeyboardInterrupt: 

In [18]:
@torch.no_grad()
def generate(model, start=[8, 2, 3, 0, 2, 2, 3, 0, 0, 2, 1, 2, 2, 2, 2, 3], max_len=32):
    model.eval()
    start = torch.tensor(start, dtype=torch.long).unsqueeze(0).to(device)
    output = start
    for _ in range(max_len):
        outputs = model(output)

        outputs = outputs[:, -1, :].softmax(1)

        predicted = torch.multinomial(outputs, 1).squeeze(1)

        output = torch.cat((output, predicted.unsqueeze(0)), 1)

        break
    return output

generated = generate(model, start=[8])
decode_brackets(generated[0])

'SOS <'

In [133]:
start

[8, 2, 3, 0, 2, 2, 3, 0, 0, 2, 1, 2, 2, 2, 2, 3]