<a href="https://colab.research.google.com/github/iliaMalinovskii/MultiGrid-Transformer/blob/main/Latent_structure_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Default title text
import copy
import random
import itertools

!pip install nltk
import nltk
import os
from google.colab import drive
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer
import spacy

import datetime

import torch
import torch.nn as nn
from torch.nn import functional as F

# Fetch the NLTK resources this notebook relies on (network side effect at
# import time; each call is a no-op when the package is already present).
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('punkt_tab')

#...

# The tagger model is stored in an explicit directory, which is then added to
# NLTK's search path so later lookups can find it.
nltk.download('averaged_perceptron_tagger', download_dir='/root/nltk_data')
nltk.data.path.append('/root/nltk_data') # Tell nltk to include the new directory in the search path

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize, RegexpTokenizer
nltk.download('stopwords')
# NOTE(review): presumably the resource nltk.pos_tag (used in
# Workflow.pos_dict) loads on recent NLTK versions - confirm.
nltk.download('averaged_perceptron_tagger_eng')
import re

#...

# Fixed seed so parameter initialisation and batch sampling are repeatable.
torch.manual_seed(1337)

class Workflow:
    """Data pipeline for the multi-stream language model.

    Loads raw texts from Google Drive, encodes them as several parallel,
    equally long integer sequences ("MAIN" characters, "MORPH" morpheme ids,
    "POS" part-of-speech ids), splits each 90/10 into train/validation and
    serves random training batches.

    Attributes:
        log: free-form run log (filled by the caller).
        alltext: the filtered text as one string (set by ``load_texts``).
        decode_dict: MAIN id -> character (filled by ``dictionaries``).
        pos_list: the POS tags known to ``pos_dict``.
        list_of_seq / data / train_data / val_data / vocab_sizes: one entry
            per loaded stream, in loading order.
    """

    # Number of classes reserved for the POS stream.  Only ids
    # 0..len(pos_list) are ever produced; the value is kept at 101 because
    # GPTLanguageModel.forward stacks the per-frame logits, which requires
    # every frame vocabulary to have the same size.
    POS_VOCAB_SIZE = 101

    def __init__(self):
        self.log = {}
        self.alltext = ""
        self.decode_dict = {}
        self.pos_list = []
        self.list_of_seq = []
        self.data = []
        self.train_data = []
        self.val_data = []
        self.vocab_sizes = []

    def load_texts(self, max_text=5, base_tokens_dict=None,
                   drive_mount_path='/content/drive',
                   folder_path='/content/drive/MyDrive/Colab Notebooks/fairy_tales'):
        """Read up to ``max_text`` ``.txt`` files and return their characters.

        Each file is lower-cased, whitespace runs are collapsed to a single
        space, and every character that is not a key of ``base_tokens_dict``
        is dropped.  Consecutive files are separated by one space.

        Args:
            max_text: maximum number of files to load.
            base_tokens_dict: mapping whose keys are the allowed characters.
            drive_mount_path: where Google Drive is mounted.
            folder_path: directory that is scanned for ``.txt`` files.

        Returns:
            The list of kept characters; ``self.alltext`` holds the same
            characters joined into one string.
        """
        drive.mount(drive_mount_path)
        text_count = 0
        base_token_seq = []
        for filename in os.listdir(folder_path):
            if not filename.endswith(".txt"):
                continue
            filepath = os.path.join(folder_path, filename)
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    file_text = f.read().lower()  # TODO: remove first row, that is title
                file_text = ' '.join(file_text.strip().split())
                file_text = re.sub(r"\s+", " ", file_text).strip()
                if text_count > 0:
                    file_text = " " + file_text
                prev_t = None
                for t in file_text:
                    if base_tokens_dict.get(t) is None:
                        continue
                    # Dropping characters can leave two spaces adjacent;
                    # keep only the first one.
                    if t == " " and prev_t == " ":
                        continue
                    base_token_seq.append(t)
                    prev_t = t
                text_count += 1
                if (100 * (text_count / max_text)) % 20 == 0:
                    print(f"{text_count} texts loaded")
                if text_count >= max_text:
                    break
            except Exception as e:
                # Best effort: report the unreadable file and move on.
                print(f"Error reading file {filename}: {e}")

        self.alltext = ''.join(base_token_seq)
        return base_token_seq

    def dictionaries(self):
        """Build the character ("MAIN") and morpheme ("MORPH") dictionaries.

        MAIN maps each character of the alphabet to ids 0..36 and also fills
        ``self.decode_dict`` with the reverse mapping.  MORPH maps each
        morpheme to ids starting at 1 (0 is left for "no morpheme").

        Returns:
            ``{"MAIN": {char: id}, "MORPH": {morpheme: id}}``.
        """
        # Dictionary #1 MAIN
        alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789 '
        alphabet_dict = {ch: value for value, ch in enumerate(alphabet)}
        for ch, value in alphabet_dict.items():
            self.decode_dict[value] = ch

        # Dictionary #2 MORPH
        morphemes_list = ["ab", "ad", "ante", "anti", "auto", "ation", "ative",
                          "be", "bi", "circum", "co", "com", "con", "counter",
                          "de", "dis", "em", "en", "epi", "es", "eu", "ex", "extra",
                          "hyper", "hypo", "ible", "il", "im", "in", "inter",
                          "intra", "ion", "ir", "ise", "iso", "ition", "itive",
                          "mal", "mid", "mis", "mono", "non", "ob", "omni", "or",
                          "out", "over", "post", "pre", "pro", "re", "semi",
                          "sub", "super", "trans", "ty", "un", "under", "uni",
                          "vice", "ward", "with", "wise", "able", "al", "ance",
                          "ant", "ary", "ate", "dom", "ed", "ence", "ency", "er",
                          # "fore" and "ful" are two morphemes; a misplaced
                          # comma used to fuse them into "fore,ful".
                          "est", "eous", "fore", "ful", "fy", "hood", "ic", "ical",
                          "ial", "ify", "ing", "ious", "ism", "ist", "ity", "ive",
                          "ize", "less", "ly", "ment", "ness", "ous", "ship",
                          "sion", "tion", "ure"]
        morphemes_dict = {m: value for value, m in enumerate(morphemes_list, start=1)}

        return {"MAIN": alphabet_dict, "MORPH": morphemes_dict}

    def pos_dict(self, default_value=0):
        """Tag ``self.alltext`` and return one POS id per character.

        The text is split on whitespace and tagged with ``nltk.pos_tag``.
        Every character of a word receives the id of the word's tag (1-based
        index into ``self.pos_list``); unknown tags and the single separator
        emitted between two words receive ``default_value``.

        Args:
            default_value: id used for separators and unknown tags.

        Returns:
            A list of ints.
        """
        # Dictionary #3 (Penn Treebank tag names)
        self.pos_list = [
            "CC",    # coordinating conjunction
            "CD",    # cardinal digit
            "DT",    # determiner
            "EX",    # existential there ("there is")
            "FW",    # foreign word
            "IN",    # preposition / subordinating conjunction
            "JJ",    # adjective - big
            "JJR",   # adjective, comparative - bigger
            "JJS",   # adjective, superlative - biggest
            "LS",    # list marker - 1)
            "MD",    # modal - could, will
            "NN",    # noun, singular - desk
            "NNS",   # noun, plural - desks
            "NNP",   # proper noun, singular - Harrison
            "NNPS",  # proper noun, plural - Americans
            "PDT",   # predeterminer - all the kids
            "POS",   # possessive ending - parent's
            "PRP",   # personal pronoun - I, he, she
            "PRP$",  # possessive pronoun - my, his, hers
            "RB",    # adverb - very, silently
            "RBR",   # adverb, comparative - better
            "RBS",   # adverb, superlative - best
            "RP",    # particle - give up
            "TO",    # to - go 'to' the store
            "UH",    # interjection
            "VB",    # verb, base form - take
            "VBD",   # verb, past tense - took
            "VBG",   # verb, gerund / present participle - taking
            "VBN",   # verb, past participle - taken
            "VBP",   # verb, sing. present, non-3rd - take
            "VBZ",   # verb, 3rd person sing. present - takes
            "WDT",   # wh-determiner - which
            "WP",    # wh-pronoun - who, what
            "WP$",   # possessive wh-pronoun - whose
            "WRB",   # wh-adverb - where, when
        ]
        tag_ids = {tag: value for value, tag in enumerate(self.pos_list, start=1)}

        tagged = nltk.pos_tag(self.alltext.split())
        last = len(tagged) - 1
        pos_tokenized = []
        for position, (word, tag) in enumerate(tagged):
            pos_tokenized.extend([tag_ids.get(tag, default_value)] * len(word))
            if position < last:
                # One slot for the space between this word and the next.
                pos_tokenized.append(default_value)
        return pos_tokenized

    @staticmethod
    def translate(dictionary, tokens, default_value=0):
        """Label every token with the id of the dictionary key covering it.

        Keys are tried from shortest to longest.  For each key the token list
        is scanned left to right; wherever the joined tokens spell the key and
        the first position is still unlabeled (equal to ``default_value``),
        all covered positions receive ``dictionary[key]``.

        Declared as a static method so it can be called both as
        ``Workflow.translate(...)`` and on an instance.

        Args:
            dictionary: mapping from string key to int id.
            tokens: sequence of single-character strings.
            default_value: id meaning "not labeled".

        Returns:
            A list of ints with ``len(tokens)`` entries.
        """
        result = [default_value] * len(tokens)
        for k, key in enumerate(sorted(dictionary, key=len)):
            if k % 50 == 0:
                print(k)  # coarse progress indicator
            key_len = len(key)
            if key_len == 0:
                continue  # an empty key can never label a token
            value = dictionary[key]
            last_start = len(tokens) - key_len
            i = 0
            while i <= last_start:
                if result[i] != default_value:
                    i += 1
                    continue
                if ''.join(tokens[i:i + key_len]) == key:
                    result[i:i + key_len] = [value] * key_len
                    i += key_len
                else:
                    i += 1
        return result

    def _add_stream(self, seq, vocab_size, split_at):
        """Register one encoded sequence together with its train/val split."""
        tensor = torch.tensor(seq, dtype=torch.long)
        self.list_of_seq.append(seq)
        self.data.append(tensor)
        self.vocab_sizes.append(vocab_size)
        self.train_data.append(tensor[:split_at])
        self.val_data.append(tensor[split_at:])

    def data_prep(self, max_text=5, default_value=0, load_seq=None):
        """Load and encode the streams named in ``load_seq``, in that order.

        "MAIN" loads the texts and fixes the 90/10 split index; "POS" and
        "MORPH" are derived from the loaded text and reuse that index, so
        "MAIN" has to come first.  Names other than these three are ignored.

        Args:
            max_text: maximum number of text files to load.
            default_value: id used for positions without a label.
            load_seq: iterable of stream names; ``None`` loads nothing.

        Raises:
            ValueError: if "POS" or "MORPH" is requested before "MAIN".
        """
        split_at = None
        for name in (load_seq or ()):
            if name in ("POS", "MORPH") and split_at is None:
                raise ValueError(f"'MAIN' must be loaded before '{name}'")

            if name == "MAIN":
                self.list_of_dicts = self.dictionaries()
                main_dict = self.list_of_dicts.get("MAIN")
                self.base_token_seq = self.load_texts(
                    max_text=max_text, base_tokens_dict=main_dict)
                main_seq = Workflow.translate(
                    main_dict, self.base_token_seq, default_value=default_value)
                split_at = int(0.9 * len(main_seq))
                self._add_stream(main_seq, len(main_dict), split_at)
            elif name == "POS":
                self._add_stream(self.pos_dict(), self.POS_VOCAB_SIZE, split_at)
            elif name == "MORPH":
                morph_dict = self.list_of_dicts.get("MORPH")
                morph_seq = Workflow.translate(
                    morph_dict, self.base_token_seq, default_value=default_value)
                # Ids span 0..max, i.e. max + 1 classes.  With 100 morphemes
                # this is 101, the same size as the POS stream, which
                # GPTLanguageModel.forward needs in order to stack the
                # per-frame logits.
                self._add_stream(morph_seq, max(morph_dict.values()) + 1, split_at)

    def get_batch(self, split):
        """Sample one random batch of inputs and next-token targets.

        Reads the module-level ``batch_size``, ``block_size`` and ``device``.
        The same random offsets are used for every stream so the streams stay
        aligned.

        Args:
            split: ``'train'`` selects the training data, anything else the
                validation data.

        Returns:
            ``(x, y)``: two lists with one tensor per stream; ``y`` is ``x``
            shifted by one position.
        """
        data = self.train_data if split == 'train' else self.val_data
        ix = torch.randint(len(data[0]) - block_size - 1, (batch_size,))
        x = []
        y = []
        for stream in data:
            x.append(torch.stack([stream[i:i + block_size] for i in ix]).to(device))
            y.append(torch.stack([stream[i + 1:i + 1 + block_size] for i in ix]).to(device))
        return x, y

@torch.no_grad()
def estimate_loss():
    """Average the model loss over ``eval_iters`` batches per split.

    Uses the module-level ``model``, ``w``, ``eval_iters`` and ``input_dim``.
    The model is put in eval mode for the measurement and back in train mode
    before returning.

    Returns:
        ``(out, out_detailed)``: ``out[split]`` is the mean combined loss and
        ``out_detailed[split]`` a list with the mean loss of each stream, for
        ``split`` in ``'train'`` and ``'val'``.
    """
    combined = {}
    per_stream = {}
    stream_buffer = torch.zeros((input_dim, eval_iters))
    model.eval()
    for split in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = w.get_batch(split)
            _, total, parts = model(inputs, targets)
            batch_losses[step] = total.item()
            for stream in range(input_dim):
                stream_buffer[stream][step] = parts[stream].item()
        combined[split] = batch_losses.mean()
        per_stream[split] = [stream_buffer[stream].mean() for stream in range(input_dim)]
    model.train()
    return combined, per_stream

class Head(nn.Module):
    """A single head of masked (causal) self-attention.

    Layer sizes come from the module-level ``n_embd``, ``block_size`` and
    ``dropout``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may attend to positions <= t.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (B, T, C) to an output of shape (B, T, head_size)."""
        _, seq_len, _ = x.shape
        keys = self.key(x)        # (B, T, hs)
        queries = self.query(x)   # (B, T, hs)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B, T, T)
        hidden = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(x)    # (B, T, hs)
        return weights @ values   # (B, T, hs)

class HeadXAttn(nn.Module):
    """A single head of masked cross-attention between two streams.

    Queries are computed from ``x``; keys and values from ``w``.  Layer sizes
    come from the module-level ``n_embd``, ``block_size`` and ``dropout``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may attend to positions <= t.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, w):
        """Attend from ``x`` (B, T, C) over ``w``; returns (B, T, head_size)."""
        _, seq_len, _ = x.shape
        keys = self.key(w)
        queries = self.query(x)   # (B, T, hs)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale  # (B, T, T)
        hidden = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(hidden, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        values = self.value(w)    # (B, T, hs)
        return weights @ values   # (B, T, hs)

class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run side by side, then projected to ``n_embd``."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate all head outputs on the last axis, project, apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class MultiHeadXAttention(nn.Module):
    """Several ``HeadXAttn`` modules run side by side, then projected to ``n_embd``."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [HeadXAttn(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, w):
        """Run every head on ``(x, w)``, concatenate, project, apply dropout."""
        per_head = [head(x, w) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedFoward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as ``n_embd``; the dropout rate is
    the module-level ``dropout``.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """One layer of the multi-stream model.

    The main stream ``x`` gets self-attention and a feed-forward step.  When
    frame streams ``w`` are present, each frame additionally gets its own
    self-attention / feed-forward, then ``x`` attends to every frame, every
    frame attends back to ``x``, and the per-frame copies of ``x`` are
    averaged into a single main stream again.
    """

    def __init__(self, n_embd, n_head, num_frames):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head

        def per_frame(make):
            """One independently parameterised module per frame."""
            return nn.ModuleList([make() for _ in range(num_frames)])

        def layer_norm():
            return nn.LayerNorm(n_embd)

        def feed_forward():
            return FeedFoward(n_embd)

        def self_attention():
            return MultiHeadAttention(n_head, head_size)

        def cross_attention():
            return MultiHeadXAttention(n_head, head_size)

        # Main stream ("t").
        self.ln_t1 = layer_norm()
        self.sa_t = self_attention()
        self.ln_t2 = layer_norm()
        self.ffwd_t = feed_forward()

        # Frame streams, self-attention ("f").
        self.ln_f1 = per_frame(layer_norm)
        self.sa_f = per_frame(self_attention)
        self.ln_f2 = per_frame(layer_norm)
        self.ffwd_f = per_frame(feed_forward)

        # Main stream attending to each frame ("x").
        self.ln_x1 = per_frame(layer_norm)
        self.ln_w1 = per_frame(layer_norm)
        self.sa_x = per_frame(cross_attention)
        self.ln_x2 = per_frame(layer_norm)
        self.ffwd_x = per_frame(feed_forward)

        # Each frame attending to the main stream ("w").
        self.sa_w = per_frame(cross_attention)
        self.ln_w2 = per_frame(layer_norm)
        self.ffwd_w = per_frame(feed_forward)

    def forward(self, x, w):
        """Return the updated ``(x, w)``; ``w`` may be ``None`` (no frames).

        ``w`` is iterated along its first axis, one slice per frame.
        """
        x = x + self.sa_t(self.ln_t1(x))
        x = x + self.ffwd_t(self.ln_t2(x))
        if w is None:
            return x, w

        # Per-frame self-attention and feed-forward.
        frame_attn = [attn(norm(frame))
                      for attn, norm, frame in zip(self.sa_f, self.ln_f1, w)]
        w = w + torch.stack(frame_attn, dim=0)
        frame_mlp = [mlp(norm(frame))
                     for mlp, norm, frame in zip(self.ffwd_f, self.ln_f2, w)]
        w = w + torch.stack(frame_mlp, dim=0)

        # One copy of the main stream per frame: (num_frames, B, T, C).
        x = x.unsqueeze(0).repeat(w.shape[0], 1, 1, 1)

        # Main stream attends to each frame.
        main_cross = [attn(norm_x(main), norm_w(frame))
                      for attn, norm_x, norm_w, main, frame
                      in zip(self.sa_x, self.ln_x1, self.ln_w1, x, w)]
        x = x + torch.stack(main_cross, dim=0)
        main_mlp = [mlp(norm(main))
                    for mlp, norm, main in zip(self.ffwd_x, self.ln_x2, x)]
        x = x + torch.stack(main_mlp, dim=0)

        # Each frame attends back to its copy of the main stream.
        frame_cross = [attn(norm_w(frame), norm_x(main))
                       for attn, norm_w, norm_x, frame, main
                       in zip(self.sa_w, self.ln_w1, self.ln_x1, w, x)]
        w = w + torch.stack(frame_cross, dim=0)
        frame_mlp2 = [mlp(norm(frame))
                      for mlp, norm, frame in zip(self.ffwd_w, self.ln_w2, w)]
        w = w + torch.stack(frame_mlp2, dim=0)

        # Average x across the num_frames dimension.
        x = torch.mean(x, dim=0)
        return x, w

class GPTLanguageModel(nn.Module):
    """Transformer language model over one main stream plus optional frames.

    ``vocab_sizes[0]`` is the vocabulary of the main stream; every further
    entry adds one "frame" stream with its own embeddings, final layer norm
    and output head.  Reads the module-level hyperparameters ``n_embd``,
    ``n_head``, ``n_layer``, ``block_size`` and
    ``main_seq_loss_contribution_weight``.
    """

    def __init__(self, vocab_sizes):
        super().__init__()
        frame_vocabs = vocab_sizes[1:]
        num_frames = len(frame_vocabs)

        self.token_embedding_table = nn.Embedding(vocab_sizes[0], n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_sizes[0])

        self.frame_embedding_table = nn.ModuleList(
            [nn.Embedding(v, n_embd) for v in frame_vocabs])
        self.frame_position_emb_table = nn.ModuleList(
            [nn.Embedding(block_size, n_embd) for _ in range(num_frames)])
        self.ln_w = nn.ModuleList(
            [nn.LayerNorm(n_embd) for _ in range(num_frames)])
        self.lm_head_w = nn.ModuleList(
            [nn.Linear(n_embd, v) for v in frame_vocabs])

        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head, num_frames=num_frames)
              for _ in range(n_layer)])
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) weights for Linear/Embedding layers, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits for every stream and, given targets, the losses.

        Args:
            idx: sequence of (B, T) integer tensors; ``idx[0]`` is the main
                stream, the rest are frame streams.
            targets: optional sequence of (B, T) tensors aligned with ``idx``.

        Returns:
            ``((logits, logits_w), loss, individual_losses)``.  ``logits_w``
            is the stack of frame logits, or ``None`` without frames.
            ``loss`` and ``individual_losses`` are ``None`` when ``targets``
            is ``None``; otherwise ``logits`` is flattened to (B*T, vocab),
            ``individual_losses`` holds one cross-entropy per stream and
            ``loss`` is the weighted mix of the main loss and the mean frame
            loss.
        """
        B, T = idx[0].shape  # every stream has the same shape
        # Build position indices on the input's device instead of relying on
        # a module-level ``device`` name.
        positions = torch.arange(T, device=idx[0].device)

        tok_emb = self.token_embedding_table(idx[0])          # (B, T, C)
        pos_emb = self.position_embedding_table(positions)    # (T, C)
        x = tok_emb + pos_emb                                 # (B, T, C)

        # One flag for the whole pass: frames are used only when the caller
        # supplied them and the model was built with frame tables.
        has_frames = len(idx) > 1 and len(self.frame_embedding_table) > 0
        if has_frames:
            w = torch.stack(
                [emb(stream) + pos(positions)
                 for emb, pos, stream in zip(self.frame_embedding_table,
                                             self.frame_position_emb_table,
                                             idx[1:])],
                dim=0)  # (num_frames, B, T, C)
        else:
            w = None

        for block in self.blocks:
            x, w = block(x, w)

        x = self.ln_f(x)            # (B, T, C)
        logits = self.lm_head(x)    # (B, T, vocab_size)

        if has_frames:
            # NOTE: stacking requires all frame vocabularies to be equal in size.
            logits_w = torch.stack(
                [head(norm(frame))
                 for head, norm, frame in zip(self.lm_head_w, self.ln_w, w)],
                dim=0)
        else:
            logits_w = None

        # Both stay None at inference time so the return statement is valid
        # whether or not targets were given.
        loss = None
        individual_losses = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            individual_losses = [F.cross_entropy(logits, targets[0].view(B * T))]
            if has_frames:
                individual_losses.extend(
                    F.cross_entropy(lw.view(B * T, lw.shape[-1]), tw.view(B * T))
                    for lw, tw in zip(logits_w, targets[1:]))
                frame_losses = individual_losses[1:]
                loss = (main_seq_loss_contribution_weight * individual_losses[0]
                        + (1 - main_seq_loss_contribution_weight)
                        * sum(frame_losses) / len(frame_losses))
            else:
                loss = individual_losses[0]

        return (logits, logits_w), loss, individual_losses

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` new tokens for every stream.

        Args:
            idx: list of (B, T) index tensors, main stream first.
            max_new_tokens: number of sampling steps.

        Returns:
            A list of (B, T + max_new_tokens) tensors, one per stream.  Each
            stream is sampled from its own output head.
        """
        max_context = self.position_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            # Crop every stream to the last ``max_context`` tokens.
            idx_cond = [stream[:, -max_context:] for stream in idx]
            (logits, logits_w), _, _ = self(idx_cond)
            stream_logits = [logits] if logits_w is None else [logits, *logits_w]

            extended = []
            for stream, stream_logit in zip(idx, stream_logits):
                last_step = stream_logit[:, -1, :]                    # (B, vocab)
                probs = F.softmax(last_step, dim=-1)                  # (B, vocab)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
                extended.append(torch.cat((stream, next_token), dim=1))
            idx = extended
        return idx

def save_dict_to_csv_gdrive(dictionary, filename, folder_path='/content/drive/My Drive/Colab Notebooks/'):
  """Appends a dictionary to a CSV file in Google Drive, one row per item.

  Args:
      dictionary: The dictionary to save; each item becomes a ``key,value`` row.
      filename: The name of the CSV file.
      folder_path: The path to the folder in Google Drive where the file should be saved.
  """
  # Imported here because no file-level ``import csv`` is visible; without
  # it the call below raises NameError.
  import csv

  drive.mount('/content/drive')
  filepath = os.path.join(folder_path, filename)

  # Append mode keeps earlier runs; an explicit encoding makes the output
  # independent of the platform default.
  with open(filepath, 'a', newline='', encoding='utf-8') as csvfile:
    csv.writer(csvfile).writerows(dictionary.items())


# ------ hyperparameters ------------
# Module-level settings read as globals by the classes and functions above.
batch_size = 4  # number of sequences sampled per batch (Workflow.get_batch)
block_size = 128  # context length: slice length in get_batch and size of the position embeddings
max_iters =  5000  # number of optimisation steps in the training loop
eval_interval = 250  # evaluate and print losses every this many steps
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
eval_iters = 200  # batches averaged per split in estimate_loss
n_embd = 48  # embedding width; each head gets n_embd // n_head channels
n_head = 16
n_layer = 5  # number of Block layers
dropout = 0.2
seq_order = ['MAIN','MORPH','POS']  # streams to load, in order; passed to Workflow.data_prep
input_dim = len(seq_order)  # number of parallel streams
max_texts = 10  # passed to Workflow.data_prep as max_text
main_seq_loss_contribution_weight = 0.8  # weight of the main-stream loss; frames share the remainder

# ------------------------------------

cpu


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


In [None]:
_save_needed = False


def _timestamp():
    """Return (year, month, day, hour, minute) taken from one clock reading."""
    now = datetime.datetime.now()
    return (now.year, now.month, now.day, now.hour, now.minute)


# ---- data ----
w = Workflow()
w.log["new_model_start"] = _timestamp()
w.data_prep(max_text=max_texts, load_seq = seq_order)

# ---- run log (insertion order is the row order of the saved CSV) ----
w.log["memo"] = "test"
w.log["batch_size"] = batch_size
w.log["block_size"] = block_size
w.log["max_iters"] = max_iters
w.log["eval_interval"] = eval_interval
w.log["learning_rate"] = learning_rate
w.log["device"] = device
w.log["eval_iters"] = eval_iters
w.log["n_embd"] = n_embd
w.log["n_head"] = n_head
w.log["n_layer"] = n_layer
w.log["dropout"] = dropout
w.log["max_texts"] = max_texts
w.log["vocab_size"] = w.vocab_sizes

# ---- model ----
model = GPTLanguageModel(w.vocab_sizes)
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# ---- training ----
# The loop variable is called ``step`` so the builtin ``iter`` is not shadowed
# for the rest of the notebook.
for step in range(max_iters):
    if step % eval_interval == 0 or step == max_iters - 1:
        losses, lossess_detailed = estimate_loss()
        # First line: combined loss; following lines: one per stream.
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        for i in range(input_dim):
            print(f"step {step}: train loss {lossess_detailed['train'][i]:.4f}, val loss {lossess_detailed['val'][i]:.4f}")

    x, y = w.get_batch('train')
    (logits, logits_w), loss, individual_losses = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# ---- sampling ----
# Seed one (1, 1) zero tensor per stream the model was built with, so every
# frame takes part in generation.
idx = [torch.zeros((1, 1), dtype=torch.long, device=device)
       for _ in range(len(w.vocab_sizes))]

idx = model.generate(idx, max_new_tokens=500)
# Decode the main stream of the first (only) batch row.
s = ''.join(str(w.decode_dict.get(t)) for t in idx[0][0].tolist())

print(s)

w.log["sample_output"] = s
w.log["model_end"] = _timestamp()

if _save_needed: save_dict_to_csv_gdrive(w.log, 'log_data.csv')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
2 texts loaded
4 texts loaded
6 texts loaded
8 texts loaded
10 texts loaded
0
0
50
1.026383 M parameters
step 0: train loss 3.8299, val loss 3.8290
step 0: train loss 3.6258, val loss 3.6244
step 0: train loss 4.6472, val loss 4.6478
step 0: train loss 4.6455, val loss 4.6466
step 250: train loss 2.2350, val loss 2.2206
step 250: train loss 2.4302, val loss 2.4168
step 250: train loss 1.2201, val loss 1.2111
step 250: train loss 1.6877, val loss 1.6601
step 500: train loss 2.0534, val loss 2.0485
step 500: train loss 2.2945, val loss 2.2905
step 500: train loss 1.0203, val loss 1.0318
step 500: train loss 1.1576, val loss 1.1294
step 750: train loss 1.9517, val loss 1.9380
step 750: train loss 2.1938, val loss 2.1826
step 750: train loss 0.9510, val loss 0.9384
step 750: train loss 1.0150, val loss 0.9809
step 1000: train loss 1.8051, val loss 1.8077
step 100