# SCAN dataset

In [45]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, RandomSampler

In [None]:

class SCANDataset(Dataset):
    """SCAN command/action pairs read from a text file and encoded as ids.

    Every non-blank line of the file must have the form
    ``IN: <words> OUT: <actions>``; both sides are split on whitespace.
    Blank lines are skipped.

    Vocabularies map token -> integer id. Id 0 is always ``'<pad>'``; it is
    the only special token added (no start/end-of-sequence markers are
    inserted). Unseen tokens receive the next id after the current largest
    one, so ids stay within ``range(len(vocab))`` for a contiguous
    vocabulary, including one passed in from another dataset.

    Args:
        file_path: Path to the SCAN text file.
        input_vocab: Optional existing token -> id dict for the command side.
            A non-empty dict is extended in place (it is not copied).
        output_vocab: Optional existing token -> id dict for the action side.
            A non-empty dict is extended in place (it is not copied).

    Raises:
        ValueError: If a non-blank line lacks the ``IN:`` marker or has no
            ``OUT:`` marker after it.
    """

    def __init__(self, file_path, input_vocab=None, output_vocab=None):
        self.data = []
        self.input_vocab = input_vocab or {}
        self.output_vocab = output_vocab or {}

        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                example = raw_line.strip()
                if not example:
                    # Blank lines carry no example; str.find() would
                    # otherwise return -1 and slice garbage out of them.
                    continue
                self.data.append(self._parse_example(example, line_no))

        self._extend_vocab(
            self.input_vocab,
            (word for words, _ in self.data for word in words),
        )
        self._extend_vocab(
            self.output_vocab,
            (action for _, actions in self.data for action in actions),
        )

        # Inverse vocabularies (id -> token) for decoding model output.
        self.inverse_input_vocab = {v: k for k, v in self.input_vocab.items()}
        self.inverse_output_vocab = {v: k for k, v in self.output_vocab.items()}

    @staticmethod
    def _parse_example(example, line_no):
        """Split one ``IN: ... OUT: ...`` line into (input_tokens, output_tokens)."""
        in_marker = example.find("IN:")
        out_marker = example.find("OUT:")
        if in_marker == -1 or out_marker < in_marker:
            raise ValueError(
                f"Line {line_no} is not of the form 'IN: ... OUT: ...': {example!r}"
            )
        input_seq = example[in_marker + len("IN:"):out_marker].split()
        output_seq = example[out_marker + len("OUT:"):].split()
        return input_seq, output_seq

    @staticmethod
    def _extend_vocab(vocab, tokens):
        """Reserve id 0 for '<pad>' and give each unseen token the next free id."""
        vocab['<pad>'] = 0
        # Continue after the largest existing id rather than len(vocab) + 1:
        # with a pre-populated vocab the latter skips an id and produces an
        # index equal to len(vocab), which an embedding table of that size
        # cannot hold.
        next_id = max(vocab.values()) + 1
        for token in tokens:
            if token not in vocab:
                vocab[token] = next_id
                next_id += 1

    def __len__(self):
        """Return the number of parsed examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, output_ids)`` as 1-D int64 tensors for example ``idx``."""
        input_seq, output_seq = self.data[idx]
        input_ids = [self.input_vocab[word] for word in input_seq]
        output_ids = [self.output_vocab[action] for action in output_seq]
        # Explicit dtype keeps empty sequences integer-typed as well.
        return (
            torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(output_ids, dtype=torch.long),
        )

def collate_fn(batch):
    """Collate ``(input_ids, output_ids)`` pairs into right-padded batch tensors.

    Args:
        batch: Sequence of ``(input_tensor, output_tensor)`` pairs, each a
            1-D tensor of token ids.

    Returns:
        ``(padded_inputs, padded_outputs, input_lengths, output_lengths)``.
        The padded tensors are batch-first and filled with 0; the length
        tensors hold each sequence's length before padding.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    src_seqs, tgt_seqs = zip(*batch)

    # Record the true lengths so padding can be told apart downstream.
    src_lens = torch.tensor(list(map(len, src_seqs)))
    tgt_lens = torch.tensor(list(map(len, tgt_seqs)))

    return (
        pad(src_seqs, batch_first=True, padding_value=0),
        pad(tgt_seqs, batch_first=True, padding_value=0),
        src_lens,
        tgt_lens,
    )


In [49]:
# Example usage: load the SCAN tasks file and draw random batches from it.
dataset_path = "SCAN/tasks.txt"
dataset = SCANDataset(file_path=dataset_path)

# Sampling with replacement fixes the number of examples drawn per pass
# at num_samples, independent of how many examples the file contains.
sampler = RandomSampler(
    dataset,
    replacement=True,
    num_samples=100_000,
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler,
    collate_fn=collate_fn,
)

In [50]:
# Number of batches per pass over the loader: 100000 sampled examples in
# batches of 32 gives 3125.
len(dataloader)

3125

In [9]:

class Encoder(nn.Module):
    """Embedding + LSTM encoder that runs on length-packed batches.

    Args:
        input_dim: Number of rows in the embedding table (source vocab size).
        embed_dim: Embedding width.
        hidden_dim: LSTM hidden-state width.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, src, src_lengths):
        """Encode a padded batch of token ids.

        Args:
            src: Batch-first tensor of token ids.
            src_lengths: Tensor of true sequence lengths, one per row of
                ``src``; rows do not need to be sorted by length.

        Returns:
            ``(outputs, (hidden, cell))`` where ``outputs`` is the re-padded,
            batch-first LSTM output and ``(hidden, cell)`` is the final LSTM
            state.
        """
        rnn_utils = nn.utils.rnn

        # Packing lets the LSTM skip padded positions; lengths must be on CPU.
        packed_input = rnn_utils.pack_padded_sequence(
            self.embedding(src),
            src_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        packed_output, final_state = self.lstm(packed_input)

        # Back to a dense batch-first tensor.
        outputs, _ = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
        return outputs, final_state

class Decoder(nn.Module):
    """Single-step LSTM decoder: embeds one token per row and scores the next.

    Args:
        output_dim: Target vocabulary size (embedding rows and logit width).
        embed_dim: Embedding width.
        hidden_dim: LSTM hidden-state width.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, tgt, hidden, cell):
        """Advance the decoder by one time step.

        Args:
            tgt: 1-D tensor holding the current input token id for each row.
            hidden: LSTM hidden state carried in from the previous step.
            cell: LSTM cell state carried in from the previous step.

        Returns:
            ``(predictions, hidden, cell)``: per-row vocabulary scores of
            shape ``[batch_size, output_dim]`` and the updated LSTM state.
        """
        # Add a length-1 time axis: [batch_size] -> [batch_size, 1].
        step_tokens = tgt.unsqueeze(1)
        step_embedded = self.embedding(step_tokens)

        step_output, (next_hidden, next_cell) = self.lstm(step_embedded, (hidden, cell))

        # Drop the time axis again before projecting to vocabulary scores.
        predictions = self.fc(step_output.squeeze(1))
        return predictions, next_hidden, next_cell
    
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing.

    Args:
        encoder: Module called as ``encoder(src, src_lengths)`` returning
            ``(outputs, (hidden, cell))``.
        decoder: Module called as ``decoder(tokens, hidden, cell)`` returning
            ``(scores, hidden, cell)``; must expose
            ``embedding.num_embeddings`` as the target vocabulary size.
        device: Device the output score tensor is moved to.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_lengths, tgt, teacher_forcing_ratio=0.5):
        """Return decoder scores of shape ``[batch, tgt_len, tgt_vocab]``.

        ``tgt[:, 0]` is fed to the decoder as its first input and is never
        predicted: position 0 of the returned tensor stays all zeros, and the
        scores at position ``t`` are the prediction for ``tgt[:, t]``.

        NOTE(review): this treats ``tgt[:, 0]`` as a start-of-sequence token,
        but ``SCANDataset`` in this file adds only ``'<pad>'`` - confirm the
        targets actually begin with such a token.

        Args:
            src: Batch-first source token ids.
            src_lengths: True source lengths, passed through to the encoder.
            tgt: Batch-first target token ids.
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token instead of the decoder's own argmax. One
                random draw per step decides for the whole batch.
        """
        batch_size = src.size(0)
        num_steps = tgt.size(1)
        vocab_size = self.decoder.embedding.num_embeddings

        # Score buffer, filled in one time step at a time.
        all_scores = torch.zeros(batch_size, num_steps, vocab_size).to(self.device)

        # Only the encoder's final state is used to initialise the decoder.
        _, (hidden, cell) = self.encoder(src, src_lengths)

        decoder_input = tgt[:, 0]
        for step in range(1, num_steps):
            step_scores, hidden, cell = self.decoder(decoder_input, hidden, cell)
            all_scores[:, step, :] = step_scores

            # Choose the next input: ground truth or the model's greedy guess.
            use_ground_truth = torch.rand(1).item() < teacher_forcing_ratio
            greedy_guess = step_scores.argmax(1)
            if use_ground_truth:
                decoder_input = tgt[:, step]
            else:
                decoder_input = greedy_guess

        return all_scores


In [8]:
def masked_loss_fn(outputs, targets, pad_idx):
    """Mean cross-entropy over all positions whose target is not ``pad_idx``.

    Args:
        outputs: Score tensor whose last dimension is the vocabulary size.
        targets: Integer target ids with one entry per score row once both
            tensors are flattened.
        pad_idx: Target id to exclude from the loss.

    Returns:
        Scalar loss tensor.
    """
    vocab_size = outputs.size(-1)
    flat_scores = outputs.view(-1, vocab_size)
    flat_targets = targets.view(-1)
    return nn.functional.cross_entropy(flat_scores, flat_targets, ignore_index=pad_idx)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [11]:
from tqdm import tqdm
import torch.optim as optim

# Hyperparameters: vocabulary sizes come from the dataset, the rest are fixed.
INPUT_DIM = len(dataset.input_vocab)
OUTPUT_DIM = len(dataset.output_vocab)
EMBED_DIM, HIDDEN_DIM, NUM_LAYERS = 128, 256, 1
PAD_IDX = dataset.output_vocab['<pad>']
NUM_EPOCHS = 1

# Build the model on the selected device and attach an Adam optimizer.
encoder = Encoder(INPUT_DIM, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS).to(device)
decoder = Decoder(OUTPUT_DIM, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for src, tgt, src_lengths, _ in progress_bar:
        # src_lengths is not moved: the encoder calls .cpu() on it anyway.
        src = src.to(device)
        tgt = tgt.to(device)

        optimizer.zero_grad()
        outputs = model(src, src_lengths, tgt, teacher_forcing_ratio=0.5)

        # Position 0 is never predicted by the model, so score positions 1..T-1.
        loss = masked_loss_fn(
            outputs[:, 1:].contiguous(),
            tgt[:, 1:].contiguous(),
            PAD_IDX,
        )
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        # Show the latest batch loss on the progress bar.
        progress_bar.set_postfix(loss=batch_loss)

    print(f"Epoch {epoch+1}, Average Loss: {epoch_loss / len(dataloader):.4f}")


Epoch 1: 100%|██████████| 262/262 [00:25<00:00, 10.38it/s, loss=1.32] 

Epoch 1, Average Loss: 1.3999



