In [20]:
import argparse
import logging
import math
import time
import pandas as pd
import torch
import torch.nn as nn
import os
import gc
import psutil
from dictionary_corpus import Corpus, TextDataset, Vocabulary, word_tokenizer, collate_batch, tokenize
import model
from lm_argparser import lm_parser
from utils import (
    repackage_hidden,
    get_batch,
    batchify,
    save_checkpoint,
    move_to_device,
    save_val_loss_data,
    load_model,
    get_memory_usage,
    log_memory_usage,
    clear_memory
)
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import multiprocessing as mp
from simple_data import TokenDataset, get_batch_iterators
from torch.nn.functional import scaled_dot_product_attention



In [4]:
device = torch.device("cpu")


In [5]:
corpus = Corpus('/scratch2/mrenaudin/colorlessgreenRNNs/english_data')


In [6]:
eval_batch_size = 10

val_data = batchify(corpus.valid, eval_batch_size, device)
test_data = batchify(corpus.test, eval_batch_size, device)



In [10]:
corpus.train = tokenize(corpus.dictionary, os.path.join('/scratch2/mrenaudin/colorlessgreenRNNs/english_data', 'train.txt'), shuffle=True)
train_data = batchify(corpus.train, 128, device)

In [7]:
criterion = nn.CrossEntropyLoss()


In [70]:
class CBR_RNN(nn.Module):
    # goal here is to reuse CBR_RNN but with scaled dot product attention for more efficient computations.
    # Also I got rid of options such as loading pretrained embeddings, and ablating attention to simplify the code.
    # In the future if those options are needed, they can still be copy pasted from William's code as the structure hasn't changed
    def __init__(self, ntoken, ninp, nhid, nheads, dropout=0.5, device=None):
        super().__init__()
        # same layers as Timkey
        self.device = device
        self.nheads = nheads
        self.tanh = nn.Tanh()
        self.drop = nn.Dropout(dropout)
        self.score_attn = nn.Softmax(dim=-1)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.q = nn.Linear(ninp + nhid, nhid)
        self.intermediate_h = nn.Linear(nhid * 4, nhid * 4)
        self.decoder = nn.Linear(nhid, ntoken)
        self.q_norm = torch.nn.LayerNorm(nhid)
        self.int_norm = torch.nn.LayerNorm(nhid * 4)
        self.f_norm = torch.nn.LayerNorm(nhid * 3)
        self.nhid = nhid
        self.final_h = nn.Linear(nhid * 4, nhid * 3)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=nhid, num_heads=nheads, batch_first=True
        )

        self.init_weights()

    def init_weights(self):
        """Initialize model weights for better training dynamics"""
        # General initialization
        for name, param in self.named_parameters():
            if "weight" in name:
                if "norm" in name:
                    nn.init.ones_(param)
                elif "encoder" in name:
                    nn.init.normal_(param, mean=0, std=0.01)
                elif "decoder" in name:
                    nn.init.normal_(param, mean=0, std=0.01)
                else:
                    # Standard He initialization for processing layers
                    nn.init.kaiming_normal_(param, mode="fan_in", nonlinearity="tanh")
            elif "bias" in name:
                nn.init.zeros_(param)

    def init_cache(self, observation, nheads):
        """Initialize hidden state and attention caches with better initialization strategy"""
        if len(observation.size()) > 1:
            bsz = observation.size(dim=-1)
        else:
            bsz = 1

        hidden = torch.zeros(1, bsz, self.nhid).to(self.device) 
        if nheads == 1:
            key_cache = torch.zeros(bsz, 1, 1, self.nhid).to(self.device) 
            value_cache = torch.zeros(bsz, 1, 1, self.nhid).to(self.device) 
        else:
            key_cache = torch.zeros(bsz, 1, self.nhid).to(self.device) 
            value_cache = torch.zeros(bsz, 1, self.nhid).to(self.device) 
        return hidden, key_cache, value_cache


    def update_cache(self, key_cache, value_cache, hidden, key_cache_i, value_cache_i, hidden_i, nheads):
        hidden_i = hidden_i.unsqueeze(0)
        hidden = torch.cat((hidden, hidden_i), dim=0)
        if nheads == 1:
                key_cache_i = key_cache_i.unsqueeze(1).unsqueeze(1)
                value_cache_i = value_cache_i.unsqueeze(1).unsqueeze(1)
                key_cache = torch.cat((key_cache, key_cache_i), dim=2)
                value_cache = torch.cat((value_cache, value_cache_i), dim=2)
        else:
            key_cache_i = key_cache_i.unsqueeze(1)
            value_cache_i = value_cache_i.unsqueeze(1)
            key_cache = torch.cat((key_cache, key_cache_i), dim=1)
            value_cache = torch.cat((value_cache, value_cache_i), dim=1)
            
        return key_cache, value_cache, hidden
    
    
    def attention_layer(self, query, key_cache, value_cache, nheads):
        if nheads == 1:
                query = query.unsqueeze(1)
                
                # Ensure all tensors are on the same device
                if query.device != key_cache.device:
                    key_cache = key_cache.to(query.device)
                if query.device != value_cache.device:
                    value_cache = value_cache.to(query.device)
                    
                try:
                    attn_output = scaled_dot_product_attention(
                        query, key_cache, value_cache, is_causal=False
                    )
                except Exception as e:
                    logging.error(f"Error in attention computation: {str(e)}")
                    raise
                attn = attn_output.squeeze(1).squeeze(1)
                del attn_output  # No longer needed after squeezing
                query = query.squeeze(1).squeeze(1)
        else:
            attn_output, _ = self.multihead_attn(
                query, key_cache, value_cache, is_causal=False
            )
            attn = attn_output.squeeze(1)
            del attn_output  # No longer needed after squeezing
            query = query.squeeze(1)
            
        return attn, query
    
    def intermediate_layers(self, i, emb, query, attn, hidden):
        intermediate_input = torch.cat((emb[i], query, attn, hidden[-1]), -1)
        del query, attn  
        intermediate = self.drop(
            self.tanh(self.int_norm(self.intermediate_h(intermediate_input)))
        )
        del intermediate_input  
        final_output = self.drop(self.tanh(self.f_norm(self.final_h(intermediate))))
        del intermediate  
        key_cache_i, value_cache_i, hidden_i = final_output.split(self.nhid, dim=-1)
        del final_output
        return key_cache_i, value_cache_i, hidden_i
    
    def get_query(self, emb, hidden):

        combined = torch.cat((emb, hidden[-1]), -1)
        query = self.drop(self.tanh(self.q_norm(self.q(combined))))
        del combined  # No longer needed after creating query
        query = query.unsqueeze(1)
        return query
    
    def forward(self, observation, initial_cache, nheads):
        seq_len = observation.size(0)
        print('ok')
        hidden, key_cache, value_cache = initial_cache
        print('hidden.shape', hidden.shape)
        print('key_cache.shape', key_cache.shape)
        print('value_cache.shape', value_cache.shape)
        # 1. Encode observations
        emb = self.drop(self.encoder(observation))
        del observation  # No longer needed after encoding
        
        for i in range(seq_len):
            # 2. Concatenate with previous hidden state
            
            
            
            query = self.get_query(emb[i], hidden)
            
            attn, query = self.attention_layer(query, key_cache, value_cache, nheads)

            key_cache_i, value_cache_i, hidden_i = self.intermediate_layers(i, emb, query, attn, hidden)
            
            key_cache, value_cache, hidden = self.update_cache(key_cache, value_cache, hidden, key_cache_i, value_cache_i, hidden_i, nheads)

            del key_cache_i, value_cache_i, hidden_i  # No longer needed after concatenation
            print('key_cache.shape', key_cache.shape)
            print('value_cache.shape', value_cache.shape)
            print('hidden.shape', hidden.shape)
        cache = (hidden, key_cache, value_cache)
        decoded = self.decoder(hidden[1:])

        return decoded, cache


In [71]:
model = CBR_RNN(50001, ninp=128, nhid=128, nheads=1)

In [72]:
lr=0.001
optimizer = optim.Adam(model.parameters(), lr=lr)

In [83]:
def train():
    # Turn on training mode which enables dropout
    model.train()
    total_loss = 0
    start_time = time.time()

    
    for batch, i in enumerate(range(0, train_data.size(0) - 1, 35)):
        # Get batch
        data, targets = get_batch(train_data, i, 35)
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()

        cache = model.init_cache(data, 1)
        
            # Forward pass on chunk
        output,_ = model(data, cache, 1)
        
        
        # Reshape outputs and targets
        output_flat = output.reshape(-1, output.size(-1))
        targets_flat = targets.reshape(-1)
        
        # Calculate loss
        loss = criterion(output_flat, targets_flat)
        del output, output_flat, targets_flat
      

       
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

        optimizer.step()   

        total_loss += loss.item()


In [84]:
train()

ok
hidden.shape torch.Size([1, 128, 128])
key_cache.shape torch.Size([128, 1, 1, 128])
value_cache.shape torch.Size([128, 1, 1, 128])
key_cache.shape torch.Size([128, 1, 2, 128])
value_cache.shape torch.Size([128, 1, 2, 128])
hidden.shape torch.Size([2, 128, 128])
key_cache.shape torch.Size([128, 1, 3, 128])
value_cache.shape torch.Size([128, 1, 3, 128])
hidden.shape torch.Size([3, 128, 128])
key_cache.shape torch.Size([128, 1, 4, 128])
value_cache.shape torch.Size([128, 1, 4, 128])
hidden.shape torch.Size([4, 128, 128])
key_cache.shape torch.Size([128, 1, 5, 128])
value_cache.shape torch.Size([128, 1, 5, 128])
hidden.shape torch.Size([5, 128, 128])
key_cache.shape torch.Size([128, 1, 6, 128])
value_cache.shape torch.Size([128, 1, 6, 128])
hidden.shape torch.Size([6, 128, 128])
key_cache.shape torch.Size([128, 1, 7, 128])
value_cache.shape torch.Size([128, 1, 7, 128])
hidden.shape torch.Size([7, 128, 128])
key_cache.shape torch.Size([128, 1, 8, 128])
value_cache.shape torch.Size([128, 

KeyboardInterrupt: 