In [None]:
from google.colab import drive
import os
# Mount Google Drive so that the dataset/model paths under /content/drive/MyDrive used below resolve.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
from collections import Counter
import plotly.graph_objects as go
import plotly.express as px

import math
import random
import inspect
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.amp import autocast, GradScaler

# Opt in to TF32 for CUDA matmuls and for cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import wandb

In [None]:
# Token vocabulary: the ten digits followed by '=', '&' and '*'.
vocab = list("0123456789") + ['=', '&', '*']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
padding_token_index = 12  # position of '*' in vocab
end_token_index = 11      # position of '&' in vocab

# Lookup tables between characters and integer ids.
stoi = dict(zip(vocab, range(len(vocab))))
itos = dict(enumerate(vocab))


def encode(s):
    """Encoder: take a string, return the list of its token ids."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of token ids, return the corresponding string."""
    return ''.join(itos[i] for i in l)


print(encode("12=3&"))
print(decode(encode("12=3&")))

[1, 2, 10, 3, 11]
12=3&


In [None]:
# Global hyperparameters shared by the model, data-sampling and training cells.
batch_size = 1000 # how many independent sequences will we process in parallel?
block_size = 60 # what is the maximum context length for predictions?
max_iters = 5000 # number of training steps; the pretraining cell's final-eval check compares against this
eval_interval = 100 # the pretraining loop evaluates and logs the loss every `eval_interval` steps
device = 'cuda' if torch.cuda.is_available() else 'cpu' # NOTE: same definition as in the vocabulary cell
n_embd = 384 # embedding width
n_head = 6 # attention heads per layer (n_embd must be divisible by n_head)
n_layer = 6 # number of transformer blocks
dropout = 0.0
bias = True # passed to the LayerNorm and attention projections
vocab_size = len(vocab)

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and an optional learnable shift."""

    def __init__(self, ndim, bias=False):
        super().__init__()
        # Both are nn.Parameter, so the optimizer updates them during training.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-6,
        )

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Uses `F.scaled_dot_product_attention` when the installed PyTorch provides it,
    and otherwise falls back to an explicit masked-softmax implementation.

    Args:
        n_embd: embedding width; must be divisible by `n_head`.
        n_head: number of attention heads.
        dropout: dropout probability for attention weights and the output projection.
        block_size: maximum sequence length (size of the causal mask in the fallback path).
        bias: whether the Q/K/V and output projections carry a bias term.
    """

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        assert n_embd % n_head == 0, "Embedding dimension must be divisible by the number of heads."

        # Store hyperparameters
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.block_size = block_size

        # Key, Query, Value projections (computed in a single matmul)
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        # Output projection
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)

        # T-5 PE (disabled)
        # self.rel_pos_bias = T5RelativePositionBias(block_size, n_head)

        # Regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Check for Flash Attention availability
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Lower-triangular causal mask, only needed by the manual path.
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
            )

    def forward(self, x):
        """Apply causal self-attention to `x` of shape (B, T, C) and return a tensor of the same shape."""
        B, T, C = x.size()  # Batch size, sequence length, embedding dimension
        head_size = C // self.n_head

        # Compute Q, K, V: each (B, T, n_embd)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape for multi-head attention: (B, n_head, T, head_size)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )
        else:
            # Manual attention with causal masking. Without this branch `y` would be
            # unbound whenever the fused kernel is unavailable.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # scaled dot product
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))  # hide future positions
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, n_head, T, head_size)

        # Reassemble heads back to (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection and residual dropout
        y = self.resid_dropout(self.c_proj(y))
        return y

# SwiGLU used in llama
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: fc2(swish(a) * b) with [a, b] = fc1(x), followed by dropout."""

    def __init__(self, n_embd: int, dropout: float = 0.0, bias: bool = False):
        super().__init__()
        hidden = int((8/3) * n_embd)
        # fc1 produces the gate half and the value half in one projection.
        self.fc1 = nn.Linear(n_embd, 2 * hidden, bias=bias)
        self.fc2 = nn.Linear(hidden, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.fc1(x).chunk(2, dim=-1)
        gated = (gate * torch.sigmoid(gate)) * value  # swish(gate) * value
        return self.dropout(self.fc2(gated))

class Block(nn.Module):
    """One pre-norm transformer layer: self-attention and a SwiGLU feed-forward, each inside a residual connection."""

    def __init__(self, n_embd, n_head, dropout, block_size, bias=True):
        super().__init__()
        self.ln_1 = LayerNorm(n_embd, bias=bias)
        self.attn = CausalSelfAttention(n_embd, n_head, dropout, block_size, bias=bias)
        self.ln_2 = LayerNorm(n_embd, bias=bias)
        # `bias` is not forwarded here, so the feed-forward uses SwiGLUFFN's own default.
        self.mlp = SwiGLUFFN(n_embd, dropout)

    def forward(self, x):
        # Normalise first, then add the sub-layer output back onto the residual stream.
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out


class GPT(nn.Module):
    """Decoder-only transformer language model.

    The model is built from a token embedding, `n_layer` `Block`s, a final
    `LayerNorm` and a linear head onto the vocabulary. No positional embedding
    is added to the token embeddings (the `wpe` table is disabled).

    Args:
        vocab_size: number of tokens in the vocabulary.
        block_size: maximum sequence length accepted by `forward`.
        n_embd: embedding width.
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        dropout: dropout probability.
        bias: whether LayerNorm and the attention projections use a bias.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=True):
        # Initialise nn.Module before assigning any attribute.
        super().__init__()
        # Device string read by `generate`; reflects CUDA availability, not where the weights live.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        assert vocab_size is not None
        assert block_size is not None
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout
        self.bias = bias

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embd), # token embeddings
            # wpe = nn.Embedding(block_size, n_embd), # positional embeddings CHANGE, t-5 positional embedding
            drop = nn.Dropout(dropout),
            h = nn.ModuleList([Block(n_embd, n_head, dropout, block_size, bias=bias) for _ in range(n_layer)]), # a stack of n_layer blocks
            ln_f = LayerNorm(n_embd, bias=bias), # final layer norm
        ))
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) # projects the final transformer output to the vocab size

        # init all weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids `idx` of shape (b, t).

        Returns:
            (logits, loss): `logits` has shape (b, t, vocab_size) for every
            position. `loss` is the cross-entropy against `targets` (positions
            whose target is the '*' token are ignored), or None when `targets`
            is not given.
        """
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        x = self.transformer.drop(tok_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # Project onto the vocabulary exactly once; the same logits serve both the
        # loss and the caller (previously the head ran three times per training step
        # and the returned logits were cut down to the last position).
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=encode("*")[0])

        return logits, loss

In [None]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=0.00001, top_k=None):
    """
    Autoregressively extend a batch of prompts and return the decoded texts.

    Parameters:
        model (nn.Module): The model used for generation; must expose `device` and `block_size`.
        idx (torch.Tensor): Initial token ids, LongTensor of shape (b, t).
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Scaling factor for logits before softmax.
        top_k (int, optional): If specified, restricts sampling to top k tokens.

    Returns:
        list[str]: One decoded string per input sequence (prompt plus generated
        tokens), truncated before the first '&'.
    """
    batch_size, _ = idx.shape
    idx = idx.to(model.device)
    end_token = encode('&')[0]

    # Nothing to generate for an empty batch.
    if batch_size == 0:
        return []

    # Track which sequences are still active (have not produced '&' yet)
    is_active = torch.ones(batch_size, dtype=torch.bool, device=idx.device)

    for _ in range(max_new_tokens):
        # Ensure context length does not exceed model's block size
        idx_cond = idx if idx.size(1) <= model.block_size else idx[:, -model.block_size:]

        # Forward pass to get logits
        logits, _ = model(idx_cond)

        # Extract logits for the last token and apply temperature scaling
        logits = logits[:, -1, :] / temperature

        # Apply top-k filtering if necessary
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
            logits[logits < v[:, [-1]]] = -float('Inf')

        # Convert logits to probabilities
        probs = F.softmax(logits, dim=-1)

        # Sample next token
        idx_next = torch.multinomial(probs, num_samples=1)

        # Deactivate every sequence that just produced '&' (vectorised; avoids one
        # host synchronisation per sequence per step).
        is_active &= idx_next.squeeze(1) != end_token

        # Stop before appending once every sequence has produced '&'
        if not is_active.any():
            break

        # Append sampled token to sequence. Finished sequences keep receiving
        # tokens, but everything from their first '&' on is discarded below.
        idx = torch.cat((idx, idx_next), dim=1)

    # Keep only the text before the first '&' of each sequence
    return [decode(seq).split('&')[0] for seq in idx.tolist()]

In [None]:
def generate_origin_dataset(original, task, num_samples = 2000000, data_dir="/content/drive/MyDrive/URPS/Data"):
    """
    Create the base training file `origin_ds_{task}.txt` inside `data_dir`.

    For task 'copy' each line has the form "<digits>=<digits>&", where <digits>
    is a random digit string whose length is drawn uniformly from 1..original.
    If the file already exists nothing is generated.

    Parameters:
        original (int): Maximum number of digits per sample.
        task (str): Task name; only 'copy' is implemented.
        num_samples (int): Number of lines to generate.
        data_dir (str): Directory that receives the file.

    Raises:
        ValueError: if `task` is not supported (previously nothing was written
        but a success message was still printed).
    """
    file_path = os.path.join(data_dir, f"origin_ds_{task}.txt")
    if os.path.exists(file_path):
        print(f"File {file_path} already exists.\nSkipping generation.")
        return
    if task != 'copy':
        raise ValueError(f"Unsupported task: {task!r} (only 'copy' is implemented)")

    # Built once instead of once per sample.
    digits = np.array([str(i) for i in range(10)])
    a_values = np.random.randint(1, original + 1, size=num_samples)
    strings = ["".join(np.random.choice(digits, size=a)) for a in a_values]  # random digit strings
    to_write = [f"{s}={s}&" for s in strings]  # the copy target equals the input

    # write down (lines are newline-separated, no trailing newline)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("\n".join(to_write))

    print(f"{num_samples} original data for task {task} is saved in {file_path}")

In [None]:
# create 50000 OOD data, save
def generate_prompt_OOD(si_round, task, original):
    """
    Return ONE prompt string for `task` made of `original + si_round` random
    digits followed by '=' (e.g. '1235455=').

    Raises:
        ValueError: if `task` is not supported (only 'copy' is implemented).
    """
    if task != 'copy':
        raise ValueError(f"Unsupported task: {task!r} (only 'copy' is implemented)")

    strings = "".join(np.random.choice([str(i) for i in range(10)], size=si_round+original))
    return f"{strings}="


def gen_si_data(model, si_round, task, num_samples=100000, block_size=block_size, batch_size=batch_size, max_lines_to_write=50000): # length filtering
    """
    Generate self-improvement data with `model` and append it to
    si_data_r{si_round-1}.txt until the file holds `max_lines_to_write` lines.

    Prompts have `10 + si_round` digits. A generated text is kept only if the
    part after the prompt (digits plus '=') is exactly `10 + si_round`
    characters long (length filter). Each kept text is written with a trailing '&'.

    Parameters:
        model: model passed to `generate`.
        si_round (int): self-improvement round; sets the prompt length and the file name.
        task (str): task name forwarded to `generate_prompt_OOD`.
        num_samples (int): upper bound on prompts; `num_samples // batch_size + 1` batches are tried.
        block_size (int): unused; kept for interface compatibility.
        batch_size (int): prompts generated per batch.
        max_lines_to_write (int): target number of lines in the output file.
    """
    output_path = f"/content/drive/MyDrive/URPS/Data/si_data_r{si_round-1}.txt"
    num_batches = (num_samples) // batch_size + 1
    original = 10
    prompt_len = si_round + original + 1  # digits plus '='
    answer_len = si_round + original
    print(f"Generating {si_round} si data...")

    # Count existing lines once; afterwards the count is tracked in memory
    # instead of re-reading the whole file for every batch.
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as f:
            current_lines = sum(1 for _ in f)
    else:
        current_lines = 0

    for _ in range(num_batches):
        # Check the cap BEFORE generating so no batch is produced just to be thrown away.
        if current_lines >= max_lines_to_write:
            print(f"Already reached {max_lines_to_write:,} lines. Stopping early.")
            break

        # generate 'batch_size' prompts of digit length (original + si_round)
        prompts = [generate_prompt_OOD(si_round, task, original=original) for _ in range(batch_size)]
        prompt_tensor = torch.tensor([encode(p) for p in prompts], dtype=torch.long, device=device)
        # NOTE(review): max_new_tokens=35 caps the answer length; confirm it is large
        # enough for the biggest si_round that will be used.
        out_str = generate(
            model=model,
            idx=prompt_tensor,
            max_new_tokens=35,
            top_k=1
        )

        # length filter
        out_str = [text for text in out_str if len(text[prompt_len:]) == answer_len]

        # Only write as many lines as are still missing
        to_write = out_str[:max_lines_to_write - current_lines]

        # append write down
        with open(output_path, "a", encoding="utf-8") as f:
            f.writelines([line + "&\n" for line in to_write])
        current_lines += len(to_write)

    print(f"Writing complete. ")

In [None]:
def get_batch(data, batch_size=batch_size, block_size=block_size):
    """
    Sample a training batch from `data` (a list of text lines).

    For every sampled line, x is the encoded line and y is the same sequence
    shifted left by one with the end token appended; both are right-padded with
    the padding token to `block_size`.

    Returns:
        (x_tensor, y_tensor): int64 tensors of shape (batch_size, block_size) on `device`.

    Raises:
        ValueError: if a sampled line is longer than `block_size` tokens.
    """
    sampled = [line.strip() for line in random.sample(data, batch_size)]

    x_rows, y_rows = [], []
    for x_str in sampled:
        x_encoded = encode(x_str)  # encode once and reuse for the targets
        if len(x_encoded) > block_size:
            raise ValueError(
                f"Sequence of length {len(x_encoded)} exceeds block_size={block_size}: {x_str!r}"
            )
        y_encoded = x_encoded[1:] + [end_token_index]
        x_rows.append(x_encoded + [padding_token_index] * (block_size - len(x_encoded)))
        y_rows.append(y_encoded + [padding_token_index] * (block_size - len(y_encoded)))

    # Build each batch tensor in one call instead of stacking batch_size small tensors.
    x_tensor = torch.tensor(x_rows, dtype=torch.int64).to(device)
    y_tensor = torch.tensor(y_rows, dtype=torch.int64).to(device)
    return x_tensor, y_tensor

In [None]:
# Load the copy-task training lines (the file written by generate_origin_dataset(task='copy')).
# NOTE(review): this cell appears before the cell that calls generate_origin_dataset,
# so on a fresh Drive the file may not exist yet — confirm the intended run order.
with open("/content/drive/MyDrive/URPS/Data/origin_ds_copy.txt", "r", encoding="utf-8") as f:
    data = f.readlines()

In [None]:
# Sanity check: the sampled input batch has shape (batch_size, block_size).
get_batch(data)[0].shape

torch.Size([1000, 60])

In [None]:
eval_iters = 100  # number of random batches averaged per loss estimate
@torch.no_grad()
def estimate_loss(data, model):
    """
    Estimate the mean loss of `model` over `eval_iters` random batches from `data`.

    The model is switched to eval mode for the measurement and put back into
    train mode afterwards, even if evaluation raises.

    Returns:
        dict: {'loss': 0-dim tensor holding the mean loss}.
    """
    out = {}
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(data)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out['loss'] = losses.mean()
    finally:
        model.train()
    return out

In [None]:
# Helper function for multiple training models for 90%+ accuracy
def create_optimizer_and_scheduler(model, total, warm, decay):
    """
    Build an AdamW optimizer and a warmup / stable / cosine-decay LR schedule.

    The LR multiplier rises linearly from 0 to 1 over `warm` steps, stays at 1
    for `total - warm - decay` steps, then follows a half cosine from 1 to 0
    over `decay` steps. After the decay phase the multiplier stays at 0. With
    `decay == 0` there is no decay phase and the multiplier stays at 1.

    Parameters:
        model (nn.Module): model whose parameters are optimised.
        total (int): total number of scheduler steps.
        warm (int): number of linear warmup steps.
        decay (int): number of cosine decay steps at the end.

    Returns:
        (optimizer, scheduler)
    """
    # AdamW
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=5e-4,              # learning rate
        betas=(0.9, 0.99),
        eps=1e-12,
        weight_decay=0.1
    )

    # LR Scheduler
    total_steps = total
    warmup_steps = warm
    decay_steps = decay
    stable_steps = total_steps - warmup_steps - decay_steps

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # Linear warmup 0->1
        if step < warmup_steps + stable_steps or decay_steps <= 0:
            return 1.0                  # Stable (also avoids dividing by zero when decay == 0)
        # Cosine decay from 1->0; clamp so the LR does not climb back up
        # when stepping past `total`.
        decay_ratio = min(1.0, (step - warmup_steps - stable_steps) / decay_steps)
        return 0.5 * (1 + math.cos(math.pi * decay_ratio))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return optimizer, scheduler

In [None]:
# Helper function for accuracy printing for each model
def accuracy_print_one(model, num_digits, need_print=False):
    """
    Measure exact-match accuracy of `model` on 1000 random copy prompts with
    `num_digits` digits, print it and return it.

    Prompts are evaluated in chunks of at most the global `batch_size`. The last
    chunk is shortened so that exactly 1000 prompts are scored, whatever the
    batch size is (previously a batch size above 1000 evaluated nothing and
    reported 0).

    Parameters:
        model: model passed to `generate`.
        num_digits (int): number of digits in every prompt.
        need_print (bool): when True, print every wrong prediction.

    Returns:
        float: fraction of prompts whose generated text equals "<digits>=<digits>".
    """
    correct = 0
    total = 1000
    digits = [str(i) for i in range(10)]
    chunk = max(1, batch_size)

    evaluated = 0
    while evaluated < total:
        current = min(chunk, total - evaluated)
        prompts = ["".join(np.random.choice(digits, size=num_digits)) + "=" for _ in range(current)]  # random generate strings

        context = torch.tensor([encode(inp) for inp in prompts], dtype=torch.long, device=device)

        # output in batch
        output_batch = generate(model=model, idx=context, max_new_tokens=35, top_k=1)

        # Expected text is the prompt followed by a copy of its digits.
        targets = [p + p[:-1] for p in prompts]
        correct += sum([output == target for output, target in zip(output_batch, targets)])

        # if needed, print wrong answer
        if need_print:
            for inp, out, target in zip(prompts, output_batch, targets):
                if out != target:
                    print(f"   Input: {inp}")
                    print(f"  Output: {out}")
                    print(f"Expected: {target}")
                    print("-----------")

        evaluated += current

    acc = correct / total
    print(f"Accuracy for {num_digits} digits: {acc}")
    return acc


def get_avg_performance(model, num_digits):
    '''
    Return {digit_count: accuracy} for every digit count from 1 to `num_digits`,
    each measured with `accuracy_print_one`.
    '''
    return {
        num_dig: accuracy_print_one(model, num_dig, need_print=False)
        for num_dig in range(1, num_digits+1)
    }

def test_accuracy_on_digits(model, digits):
    """Average the accuracy of `model` on `digits`-digit prompts over 10 independent evaluations."""
    acc_list = [accuracy_print_one(model, digits, need_print=False) for _ in range(10)]
    return sum(acc_list)/len(acc_list)

In [None]:
def set_seeds(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device when CUDA is available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [None]:
# Authenticate with Weights & Biases before any run is created.
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mfangyua[0m ([33mfangyua-univeristy-of-michigan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Start the W&B run for the base (pretraining) model. The config is read from the
# notebook's hyperparameter variables so the logged values match what is actually
# used (the hard-coded batch_size/block_size/max_iter previously disagreed with them).
wandb.init(project="transformer_si_graphs",
           config={
            "learning_rate": 5e-4,  # matches the lr set in create_optimizer_and_scheduler
            "batch_size": batch_size,
            "block_size": block_size,
            "optimizer": "AdamW",
            "n_embd": n_embd,
            "n_head": n_head,
            "n_layer": n_layer,
            "dropout": dropout,
            "max_iter": max_iters
            },
           name= "si for 10"
)

In [None]:
# Create the base copy dataset (strings of 1..10 digits); skipped when the file already exists.
generate_origin_dataset(original=10, task='copy')

File /content/drive/MyDrive/URPS/Data/origin_ds_copy.txt already exists.
Skipping generation.


In [None]:
# set_seeds(seed=22)

In [None]:
# Base training loop that produces the base model (sc_model_0.pt).
# All step counts come from these names so the loop length, the LR schedule and
# the final-evaluation check cannot drift apart.
warmup_iters = 500
decay_iters = 1000
print(f"Start run pretrain train loop with {max_iters} steps and {warmup_iters} warm, {decay_iters} decay")

# INITIALIZE MODEL, OPTIMIZER, SCHEDULER
model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias=bias)
m = model.to(device)
with open("/content/drive/MyDrive/URPS/Data/origin_ds_copy.txt", "r", encoding="utf-8") as f:
    data = f.readlines()
optimizer, scheduler = create_optimizer_and_scheduler(model, max_iters, warmup_iters, decay_iters)

# TRAINING LOOP:
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
loss_list = []

# NOTE(review): autocast below uses bfloat16; confirm whether loss scaling is needed at all.
scaler = GradScaler('cuda')
# `step` instead of `iter`, so the builtin iter() is not shadowed in the notebook namespace.
for step in tqdm(range(max_iters), desc="Training Progress"):
    # every once in a while estimate the loss on random batches of the training data
    if step % eval_interval == 0 or step == max_iters - 1:
        losses1 = estimate_loss(data, model)['loss']
        print(f"step {step}: loss {losses1:.4f}")
        log_dict = {"Loss": losses1}
        loss_list.append(round(losses1.item(), 4))
        wandb.log(log_dict)

    # sample a batch of data
    xb, yb = get_batch(data)

    # evaluate the loss
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        logits1, loss1 = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)

    scaler.scale(loss1).backward()
    scaler.step(optimizer)
    scaler.update()

    scheduler.step()

print(f"Training finished for pretrain.\nEvaluating 11-digit accuracy...")

# evaluate final performance on 11-digit copy prompts
acc = test_accuracy_on_digits(model, 11)
print(f"Average accuracy: {acc}")
filename = f"sc_model_{0}.pt"
save_path = f"/content/drive/MyDrive/URPS/Models/{filename}"
torch.save(model.state_dict(), save_path)
print(f"Saved best model at {save_path}")

Start run pretrain train loop with 5000 steps and 500 warm, 1000 decay
10.646016 M parameters


Training Progress:   0%|          | 2/5000 [00:05<3:22:01,  2.43s/it]

step 0: loss 2.6270


Training Progress:   2%|▏         | 101/5000 [00:21<1:22:04,  1.01s/it]

step 100: loss 1.4541


Training Progress:   4%|▍         | 201/5000 [00:38<2:16:02,  1.70s/it]

step 200: loss 1.1124


Training Progress:   6%|▌         | 301/5000 [00:54<2:19:17,  1.78s/it]

step 300: loss 1.0497


Training Progress:   8%|▊         | 401/5000 [01:10<2:17:39,  1.80s/it]

step 400: loss 1.0402


Training Progress:  10%|█         | 502/5000 [01:26<1:32:55,  1.24s/it]

step 500: loss 1.0206


Training Progress:  12%|█▏        | 601/5000 [01:42<2:13:39,  1.82s/it]

step 600: loss 1.1558


Training Progress:  14%|█▍        | 701/5000 [01:58<2:08:03,  1.79s/it]

step 700: loss 1.0025


Training Progress:  16%|█▌        | 801/5000 [02:14<2:04:47,  1.78s/it]

step 800: loss 0.9967


Training Progress:  18%|█▊        | 901/5000 [02:31<2:07:12,  1.86s/it]

step 900: loss 1.0048


Training Progress:  20%|██        | 1001/5000 [02:47<2:00:00,  1.80s/it]

step 1000: loss 1.0072


Training Progress:  22%|██▏       | 1101/5000 [03:03<1:56:55,  1.80s/it]

step 1100: loss 0.9842


Training Progress:  24%|██▍       | 1201/5000 [03:20<1:34:15,  1.49s/it]

step 1200: loss 1.0022


Training Progress:  26%|██▌       | 1301/5000 [03:36<1:51:46,  1.81s/it]

step 1300: loss 0.9831


Training Progress:  28%|██▊       | 1402/5000 [03:52<1:16:38,  1.28s/it]

step 1400: loss 0.9849


Training Progress:  30%|███       | 1501/5000 [04:08<1:46:12,  1.82s/it]

step 1500: loss 0.9822


Training Progress:  32%|███▏      | 1601/5000 [04:25<1:44:50,  1.85s/it]

step 1600: loss 0.9805


Training Progress:  34%|███▍      | 1701/5000 [04:41<1:38:52,  1.80s/it]

step 1700: loss 0.9888


Training Progress:  36%|███▌      | 1801/5000 [04:57<1:31:35,  1.72s/it]

step 1800: loss 0.9822


Training Progress:  38%|███▊      | 1901/5000 [05:13<1:17:58,  1.51s/it]

step 1900: loss 0.9803


Training Progress:  40%|████      | 2001/5000 [05:29<1:22:08,  1.64s/it]

step 2000: loss 0.9816


Training Progress:  42%|████▏     | 2101/5000 [05:46<1:23:39,  1.73s/it]

step 2100: loss 0.9837


Training Progress:  44%|████▍     | 2201/5000 [06:02<1:19:31,  1.70s/it]

step 2200: loss 0.9792


Training Progress:  46%|████▌     | 2301/5000 [06:18<1:16:46,  1.71s/it]

step 2300: loss 0.9793


Training Progress:  48%|████▊     | 2401/5000 [06:34<1:17:29,  1.79s/it]

step 2400: loss 0.9801


Training Progress:  50%|█████     | 2501/5000 [06:50<1:16:13,  1.83s/it]

step 2500: loss 0.9797


Training Progress:  52%|█████▏    | 2601/5000 [07:06<59:20,  1.48s/it]

step 2600: loss 0.9795


Training Progress:  54%|█████▍    | 2701/5000 [07:22<1:05:24,  1.71s/it]

step 2700: loss 0.9785


Training Progress:  56%|█████▌    | 2801/5000 [07:38<46:48,  1.28s/it]

step 2800: loss 0.9796


Training Progress:  58%|█████▊    | 2901/5000 [07:54<50:07,  1.43s/it]

step 2900: loss 0.9987


Training Progress:  60%|██████    | 3001/5000 [08:10<57:59,  1.74s/it]

step 3000: loss 0.9774


Training Progress:  62%|██████▏   | 3101/5000 [08:26<36:50,  1.16s/it]

step 3100: loss 0.9782


Training Progress:  64%|██████▍   | 3201/5000 [08:42<50:29,  1.68s/it]

step 3200: loss 0.9774


Training Progress:  66%|██████▌   | 3301/5000 [08:58<29:42,  1.05s/it]

step 3300: loss 0.9805


Training Progress:  68%|██████▊   | 3401/5000 [09:14<45:30,  1.71s/it]

step 3400: loss 0.9767


Training Progress:  70%|███████   | 3502/5000 [09:30<28:28,  1.14s/it]

step 3500: loss 0.9780


Training Progress:  72%|███████▏  | 3601/5000 [09:46<34:56,  1.50s/it]

step 3600: loss 0.9780


Training Progress:  74%|███████▍  | 3701/5000 [10:02<38:52,  1.80s/it]

step 3700: loss 0.9799


Training Progress:  76%|███████▌  | 3801/5000 [10:18<35:21,  1.77s/it]

step 3800: loss 0.9778


Training Progress:  78%|███████▊  | 3901/5000 [10:34<31:15,  1.71s/it]

step 3900: loss 0.9773


Training Progress:  80%|████████  | 4001/5000 [10:50<20:28,  1.23s/it]

step 4000: loss 0.9774


Training Progress:  82%|████████▏ | 4101/5000 [11:06<25:45,  1.72s/it]

step 4100: loss 0.9776


Training Progress:  84%|████████▍ | 4201/5000 [11:22<23:23,  1.76s/it]

step 4200: loss 0.9758


Training Progress:  86%|████████▌ | 4301/5000 [11:39<18:19,  1.57s/it]

step 4300: loss 0.9767


Training Progress:  88%|████████▊ | 4401/5000 [11:55<17:59,  1.80s/it]

step 4400: loss 0.9755


Training Progress:  90%|█████████ | 4501/5000 [12:11<13:37,  1.64s/it]

step 4500: loss 0.9755


Training Progress:  92%|█████████▏| 4601/5000 [12:27<10:51,  1.63s/it]

step 4600: loss 0.9753


Training Progress:  94%|█████████▍| 4701/5000 [12:43<06:42,  1.35s/it]

step 4700: loss 0.9744


Training Progress:  96%|█████████▌| 4801/5000 [12:59<04:35,  1.38s/it]

step 4800: loss 0.9743


Training Progress:  98%|█████████▊| 4901/5000 [13:15<02:51,  1.73s/it]

step 4900: loss 0.9746


Training Progress: 100%|██████████| 5000/5000 [13:31<00:00,  6.16it/s]

step 4999: loss 0.9749
Training finished for pretrain.
Evaluating 11-digit accuracy...





Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 1.0
Average accuracy: 0.9997
Saved best model at /content/drive/MyDrive/URPS/Models/sc_model_0.pt


In [None]:
# Re-seed all RNGs with the default seed (42) before the cells that follow.
set_seeds()

In [None]:
# Start a W&B run for the self-improvement rounds.
# NOTE(review): "block_size": 35 and "batch_size": 1024 differ from the notebook's
# block_size (60) and batch_size (1000) variables — confirm which values this run really uses.
wandb.init(project="transformer_si_graphs",
           config={
            "learning_rate": 5e-4,
            "batch_size": 1024,
            "block_size": 35,
            "optimizer": "AdamW",
            "n_embd": 384,
            "n_head": 6,
            "n_layer": 6,
            "dropout": 0.0,
            "si_iter": 1500,
            "decay": 500
            },
           name= "si for 10 rounds with length filter"
)

In [None]:
# Close the active W&B run.
# NOTE(review): in notebook order this runs right after the wandb.init above, which would end
# that run before anything is logged; the saved output shows a loss history, so the cells were
# presumably executed in a different order — confirm.
wandb.finish()

0,1
Loss,█▃▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Loss,0.97418


In [None]:
def string_majority_vote_filter(list_of_strings, vote_threshold=0.4):
    """
    Pick the most frequent string among `list_of_strings` (e.g. the predictions
    of several models for ONE prompt).

    The winner is returned only if it received at least
    ceil(vote_threshold * N) votes, N being the number of strings; otherwise
    None is returned. Ties go to the string that appeared first. An empty
    input yields None.
    """
    if not list_of_strings:
        return None

    tally = Counter(list_of_strings)
    # max() keeps the first maximal entry, and Counter iterates in first-seen order.
    winner, votes = max(tally.items(), key=lambda item: item[1])

    quorum = math.ceil(vote_threshold * len(list_of_strings))
    return winner if votes >= quorum else None


def gen_si_data_mv(
    models,
    si_round,
    task,
    num_samples=300000,
    batch_size=1024,
    vote_threshold=0.4,  # Lower threshold for harder rounds
    max_lines_to_write=50000
):
    """
    Generate self-improvement (SI) data by majority voting across several models.

    Prompts are created in batches with generate_prompt_OOD. Every model in
    `models` completes each prompt (top_k=1, up to 35 new tokens), and for each
    prompt the completion chosen by string_majority_vote_filter is appended,
    followed by "&", as one line of
    /content/drive/MyDrive/URPS/Data/si_data_r{si_round-1}.txt.
    An existing file for the round is deleted first so results never accumulate.
    Generation stops as soon as max_lines_to_write lines have been written.

    Length filtering is currently disabled: every non-empty majority answer is kept.

    Args:
        models: models whose predictions are voted on.
        si_round: current SI round; forwarded to generate_prompt_OOD and used
            to derive the output file name.
        task: task name forwarded to generate_prompt_OOD.
        num_samples: upper bound on the number of prompts to try; rounded up
            to a whole number of batches.
        batch_size: number of prompts generated and decoded per batch.
        vote_threshold: agreement fraction passed to string_majority_vote_filter.
        max_lines_to_write: target number of lines in the output file.

    Returns:
        None. The result is the file written to disk; progress is printed.
    """
    output_path = f"/content/drive/MyDrive/URPS/Data/si_data_r{si_round-1}.txt"
    # Ceiling division: exactly enough batches to cover num_samples prompts
    # (no surplus batch when num_samples is a multiple of batch_size).
    num_batches = -(-num_samples // batch_size)
    print(f"Generating SI data for round {si_round} with majority voting...")

    # Clear previous file to prevent accumulation
    if os.path.exists(output_path):
        os.remove(output_path)

    # Running count of lines appended so far. The file was just cleared and is
    # only written by this function, so it does not need to be re-read per batch.
    lines_written = 0

    for batch in range(num_batches):
        # If we already have max_lines_to_write, stop early.
        if lines_written >= max_lines_to_write:
            print(f"Already reached {max_lines_to_write} lines. Stopping early.")
            break

        # 1. Generate a batch of prompts and encode them once for all models.
        prompts = [generate_prompt_OOD(si_round, task, original=10) for _ in range(batch_size)]
        encoded = [encode(p) for p in prompts]

        # 2. Get predictions from all models. Each model receives its own tensor.
        # NOTE(review): generate is assumed to return one decoded string per
        # prompt, in prompt order — confirm against its definition.
        all_model_outputs = []
        for model in models:
            prompt_tensor = torch.tensor(encoded, dtype=torch.long, device=device)
            outputs = generate(model=model, idx=prompt_tensor, max_new_tokens=35, top_k=1)
            all_model_outputs.append(outputs)

        # 3. Majority vote per prompt; keep non-empty winners (no length filter).
        valid_outputs = []
        for i in range(len(prompts)):
            predictions = [outputs[i] for outputs in all_model_outputs]
            best_pred = string_majority_vote_filter(predictions, vote_threshold=vote_threshold)
            if best_pred:
                valid_outputs.append(best_pred)

        # 4. Write valid outputs to file, ensuring we do not exceed the target.
        to_write = valid_outputs[:max_lines_to_write - lines_written]
        if to_write:
            with open(output_path, "a", encoding="utf-8") as f:
                f.writelines([line + "&\n" for line in to_write])
            lines_written += len(to_write)

        print(f"Batch {batch+1}/{num_batches}: {lines_written}/{max_lines_to_write} lines written.")

    # Final check: count the lines that actually landed on disk.
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as f:
            final_lines = sum(1 for _ in f)
    else:
        final_lines = 0

    print(f"Writing complete. Total lines written: {final_lines}")

In [None]:
# Self-improvement (SI) driver for the 'copy' task, rounds 1..10.
# Per round:
#   1. load five voter models and the main model from checkpoints,
#   2. generate majority-voted SI data with gen_si_data_mv,
#   3. train the main model on the previous dataset plus the up-weighted SI data,
#   4. evaluate, then save the main model and the voter checkpoints,
#   5. write the combined dataset that serves as the base for the next round.
# In round 1, load from base checkpoints; in later rounds, load the updated models.
for si_r in range(1, 11):
    si_iter = wandb.config["si_iter"]   # optimisation steps per round
    operand_len = si_r + 10             # length assumed by the sanity-check slices below
    eval_digits = operand_len + 1       # digit count passed to test_accuracy_on_digits

    # Base dataset: the original data in round 1, afterwards the combined file
    # written at the end of the previous round.
    if si_r == 1:
        base_path = "/content/drive/MyDrive/URPS/Data/origin_ds_copy.txt"
    else:
        base_path = f"/content/drive/MyDrive/URPS/Data/{si_r-1}_round_combined_ds.txt"
    si_path = f"/content/drive/MyDrive/URPS/Data/si_data_r{si_r-1}.txt"

    # --- Update the list of pretrained (voter) models continuously ---
    # NOTE(review): in round 1 all five voters load the same sc_model_0.pt, and
    # the per-round voter checkpoints saved below are copies of main_model, so
    # the five voters appear to share identical weights. With top_k=1 decoding
    # the vote would then be unanimous — confirm this is intended.
    updated_models = []
    for i in range(5):
        m = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias).to(device)
        if si_r == 1:
            ckpt = f"/content/drive/MyDrive/URPS/Models/sc_model_0.pt"
        else:
            ckpt = f"/content/drive/MyDrive/URPS/Models/pretrained_model_{i}_round_{si_r-1}.pt"
        m.load_state_dict(torch.load(ckpt, map_location=device))
        updated_models.append(m)
    models_pretrained = updated_models  # Now these are the continuously updated models

    # --- Load the main model from the previous round for training ---
    main_model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias).to(device)
    main_ckpt = f"/content/drive/MyDrive/URPS/Models/sc_model_{si_r-1}.pt"
    main_model.load_state_dict(torch.load(main_ckpt, map_location=device))

    # --- Generate new SI data using majority voting with updated models ---
    gen_si_data_mv(
        models=models_pretrained,
        si_round=si_r,
        task='copy',
        num_samples=300000,
        batch_size=1024,
        vote_threshold=0.4,
        max_lines_to_write=50000
    )

    # --- Get combined data for training ---
    # Both files are read once here and reused when the combined dataset is
    # written at the end of the round.
    with open(base_path, "r", encoding="utf-8") as f:
        data_larger = f.readlines()
    with open(si_path, "r", encoding="utf-8") as f:
        sub_data = f.readlines()

    # Sanity check for the copy task: the first operand_len characters must
    # equal the operand_len characters that follow the one-character separator.
    wrong = sum(
        1 for line in sub_data
        if line[:operand_len] != line[operand_len + 1:2 * operand_len + 1]
    )
    print(f"This filtered file has {(wrong / len(sub_data))*100}% wrong answer.")

    # Up-weight the fresh SI data by repeating it (39 + si_r) times.
    data = data_larger + sub_data * (39+si_r)
    random.shuffle(data)
    print(f"This is round {si_r}, The data used for training has {len(data)/1e6} M rows")

    # --- Training the main model ---
    optimizer, scheduler = create_optimizer_and_scheduler(main_model, si_iter, 0, wandb.config["decay"])
    print(sum(p.numel() for p in main_model.parameters())/1e6, 'M parameters')
    loss_list = []
    scaler = GradScaler('cuda')
    train_step = 0  # counts evaluation points, not optimisation steps

    # `step` rather than `iter`, so the builtin iter() is not shadowed for later cells.
    for step in tqdm(range(si_iter), desc="Training Progress"):
        if step % eval_interval == 0 or step == si_iter - 1:
            losses = estimate_loss(data, main_model)['loss']
            print(f"step {step}: loss {losses:.4f}")
            loss_list.append(round(losses.item(), 4))
            wandb.log({"train_loss": losses.item(), "train_step": train_step})
            train_step += 1

        xb, yb = get_batch(data)

        with autocast(device_type="cuda", dtype=torch.bfloat16):
            logits1, loss1 = main_model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss1).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

    print(f"Training finished for self-improve round {si_r}.\nEvaluating {eval_digits}-digit accuracy...")
    acc = test_accuracy_on_digits(main_model, eval_digits)
    digit_step = eval_digits
    wandb.log({"Accuracy": acc, "digit_step": digit_step})
    print(f"Average accuracy for {eval_digits}: {acc}")
    main_save_path = f"/content/drive/MyDrive/URPS/Models/sc_model_{si_r}.pt"
    torch.save(main_model.state_dict(), main_save_path)
    print(f"Saved best main model at {main_save_path}")

    # --- CHANGE B: Update each of the k pretrained models with the new main model's state ---
    # The next round loads its voters from these files.
    for i in range(5):
        pretrained_save_path = f"/content/drive/MyDrive/URPS/Models/pretrained_model_{i}_round_{si_r}.pt"
        # Here we simply copy the main model's state. Alternatively, you could train them independently.
        torch.save(main_model.state_dict(), pretrained_save_path)
        print(f"Saved updated pretrained model {i} at {pretrained_save_path}")

    # --- Combine data for future rounds ---
    # Base data plus one (un-repeated) copy of this round's SI data.
    data_smaller = sub_data
    print(f"This is round {si_r}, data larger has {len(data_larger)} rows")
    print(f"This is round {si_r}, data smaller has {len(data_smaller)} rows")

    data_new = data_larger + data_smaller
    random.shuffle(data_new)

    combined_save_path = f"/content/drive/MyDrive/URPS/Data/{si_r}_round_combined_ds.txt"
    with open(combined_save_path, "w", encoding="utf-8") as f:
        f.writelines([line if line.endswith("\n") else line + "\n" for line in data_new])
    print(f"{si_r}_round_combined_ds.txt has {len(data_new)} rows")

wandb.finish()

Generating SI data for round 1 with majority voting...
Batch 1/293: 1024/50000 lines written.
Batch 2/293: 2048/50000 lines written.
Batch 3/293: 3072/50000 lines written.
Batch 4/293: 4096/50000 lines written.
Batch 5/293: 5120/50000 lines written.
Batch 6/293: 6144/50000 lines written.
Batch 7/293: 7168/50000 lines written.
Batch 8/293: 8192/50000 lines written.
Batch 9/293: 9216/50000 lines written.
Batch 10/293: 10240/50000 lines written.
Batch 11/293: 11264/50000 lines written.
Batch 12/293: 12288/50000 lines written.
Batch 13/293: 13312/50000 lines written.
Batch 14/293: 14336/50000 lines written.
Batch 15/293: 15360/50000 lines written.
Batch 16/293: 16384/50000 lines written.
Batch 17/293: 17408/50000 lines written.
Batch 18/293: 18432/50000 lines written.
Batch 19/293: 19456/50000 lines written.
Batch 20/293: 20480/50000 lines written.
Batch 21/293: 21504/50000 lines written.
Batch 22/293: 22528/50000 lines written.
Batch 23/293: 23552/50000 lines written.
Batch 24/293: 24576/

Training Progress:   0%|          | 1/1500 [00:06<2:31:40,  6.07s/it]

step 0: loss 1.2251


Training Progress:   7%|▋         | 102/1500 [00:23<32:21,  1.39s/it]

step 100: loss 1.0062


Training Progress:  13%|█▎        | 201/1500 [00:39<38:56,  1.80s/it]

step 200: loss 1.0043


Training Progress:  20%|██        | 301/1500 [00:55<36:46,  1.84s/it]

step 300: loss 1.0034


Training Progress:  27%|██▋       | 402/1500 [01:11<21:47,  1.19s/it]

step 400: loss 1.0457


Training Progress:  33%|███▎      | 501/1500 [01:28<29:48,  1.79s/it]

step 500: loss 1.0048


Training Progress:  40%|████      | 601/1500 [01:44<26:53,  1.79s/it]

step 600: loss 1.0046


Training Progress:  47%|████▋     | 701/1500 [02:01<24:48,  1.86s/it]

step 700: loss 1.0056


Training Progress:  53%|█████▎    | 801/1500 [02:17<21:10,  1.82s/it]

step 800: loss 1.0579


Training Progress:  60%|██████    | 901/1500 [02:33<15:04,  1.51s/it]

step 900: loss 1.0052


Training Progress:  67%|██████▋   | 1001/1500 [02:50<14:57,  1.80s/it]

step 1000: loss 1.0189


Training Progress:  73%|███████▎  | 1101/1500 [03:06<11:45,  1.77s/it]

step 1100: loss 1.0145


Training Progress:  80%|████████  | 1201/1500 [03:23<09:20,  1.87s/it]

step 1200: loss 1.0027


Training Progress:  87%|████████▋ | 1301/1500 [03:39<06:04,  1.83s/it]

step 1300: loss 1.0021


Training Progress:  93%|█████████▎| 1401/1500 [03:55<03:00,  1.82s/it]

step 1400: loss 1.0016


Training Progress: 100%|██████████| 1500/1500 [04:12<00:00,  5.95it/s]

step 1499: loss 1.0014
Training finished for self-improve round 1.
Evaluating 12-digit accuracy...





Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 0.998
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 0.999
Accuracy for 12 digits: 1.0
Accuracy for 12 digits: 1.0
Average accuracy for 12: 0.9997
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_1.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_1.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_1.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_1.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_1.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_1.pt
This is round 1, data larger has 2000000 rows
This is round 1, data smaller has 50000 rows
1_round_combin

Training Progress:   0%|          | 1/1500 [00:06<2:31:57,  6.08s/it]

step 0: loss 1.1787


Training Progress:   7%|▋         | 101/1500 [00:22<44:55,  1.93s/it]

step 100: loss 1.0587


Training Progress:  13%|█▎        | 201/1500 [00:38<39:18,  1.82s/it]

step 200: loss 1.0632


Training Progress:  20%|██        | 301/1500 [00:55<38:47,  1.94s/it]

step 300: loss 1.0324


Training Progress:  27%|██▋       | 401/1500 [01:12<33:43,  1.84s/it]

step 400: loss 1.0153


Training Progress:  33%|███▎      | 502/1500 [01:28<21:46,  1.31s/it]

step 500: loss 1.0147


Training Progress:  40%|████      | 601/1500 [01:44<27:01,  1.80s/it]

step 600: loss 1.0365


Training Progress:  47%|████▋     | 701/1500 [02:00<24:05,  1.81s/it]

step 700: loss 1.0170


Training Progress:  53%|█████▎    | 801/1500 [02:17<17:39,  1.52s/it]

step 800: loss 1.0136


Training Progress:  60%|██████    | 901/1500 [02:33<16:16,  1.63s/it]

step 900: loss 1.0535


Training Progress:  67%|██████▋   | 1001/1500 [02:49<15:18,  1.84s/it]

step 1000: loss 1.0726


Training Progress:  73%|███████▎  | 1101/1500 [03:06<13:07,  1.97s/it]

step 1100: loss 1.0165


Training Progress:  80%|████████  | 1201/1500 [03:22<08:39,  1.74s/it]

step 1200: loss 1.0165


Training Progress:  87%|████████▋ | 1301/1500 [03:38<05:43,  1.73s/it]

step 1300: loss 1.0144


Training Progress:  93%|█████████▎| 1401/1500 [03:55<02:14,  1.36s/it]

step 1400: loss 1.0125


Training Progress: 100%|██████████| 1500/1500 [04:11<00:00,  5.96it/s]

step 1499: loss 1.0122
Training finished for self-improve round 2.
Evaluating 13-digit accuracy...





Accuracy for 13 digits: 0.999
Accuracy for 13 digits: 0.999
Accuracy for 13 digits: 1.0
Accuracy for 13 digits: 0.999
Accuracy for 13 digits: 0.999
Accuracy for 13 digits: 0.998
Accuracy for 13 digits: 1.0
Accuracy for 13 digits: 1.0
Accuracy for 13 digits: 1.0
Accuracy for 13 digits: 0.998
Average accuracy for 13: 0.9991999999999999
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_2.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_2.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_2.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_2.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_2.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_2.pt
This is round 2, data larger has 2050000 rows
This is round 2, data smaller has 50000

Training Progress:   0%|          | 1/1500 [00:06<2:32:05,  6.09s/it]

step 0: loss 1.1449


Training Progress:   7%|▋         | 101/1500 [00:22<42:49,  1.84s/it]

step 100: loss 1.0245


Training Progress:  13%|█▎        | 201/1500 [00:38<39:13,  1.81s/it]

step 200: loss 1.0238


Training Progress:  20%|██        | 301/1500 [00:55<35:37,  1.78s/it]

step 300: loss 1.0228


Training Progress:  27%|██▋       | 401/1500 [01:11<31:42,  1.73s/it]

step 400: loss 1.0236


Training Progress:  33%|███▎      | 501/1500 [01:28<32:20,  1.94s/it]

step 500: loss 1.0233


Training Progress:  40%|████      | 601/1500 [01:44<20:44,  1.38s/it]

step 600: loss 1.0228


Training Progress:  47%|████▋     | 701/1500 [02:00<24:22,  1.83s/it]

step 700: loss 1.0231


Training Progress:  53%|█████▎    | 801/1500 [02:16<16:04,  1.38s/it]

step 800: loss 1.0291


Training Progress:  60%|██████    | 901/1500 [02:33<19:21,  1.94s/it]

step 900: loss 1.0254


Training Progress:  67%|██████▋   | 1001/1500 [02:49<14:53,  1.79s/it]

step 1000: loss 1.0847


Training Progress:  73%|███████▎  | 1101/1500 [03:06<10:53,  1.64s/it]

step 1100: loss 1.0247


Training Progress:  80%|████████  | 1201/1500 [03:22<09:44,  1.95s/it]

step 1200: loss 1.0228


Training Progress:  87%|████████▋ | 1301/1500 [03:39<06:06,  1.84s/it]

step 1300: loss 1.0656


Training Progress:  93%|█████████▎| 1401/1500 [03:55<02:58,  1.80s/it]

step 1400: loss 1.0287


Training Progress: 100%|██████████| 1500/1500 [04:12<00:00,  5.94it/s]

step 1499: loss 1.0257
Training finished for self-improve round 3.
Evaluating 14-digit accuracy...





Accuracy for 14 digits: 0.997
Accuracy for 14 digits: 0.991
Accuracy for 14 digits: 0.997
Accuracy for 14 digits: 0.994
Accuracy for 14 digits: 0.994
Accuracy for 14 digits: 0.992
Accuracy for 14 digits: 0.991
Accuracy for 14 digits: 0.994
Accuracy for 14 digits: 0.996
Accuracy for 14 digits: 0.993
Average accuracy for 14: 0.9939
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_3.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_3.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_3.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_3.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_3.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_3.pt
This is round 3, data larger has 2100000 rows
This is round 3, data smaller has 50000 row

Training Progress:   0%|          | 1/1500 [00:06<2:45:20,  6.62s/it]

step 0: loss 1.1114


Training Progress:   7%|▋         | 102/1500 [00:23<31:03,  1.33s/it]

step 100: loss 1.0384


Training Progress:  13%|█▎        | 201/1500 [00:39<39:52,  1.84s/it]

step 200: loss 1.0317


Training Progress:  20%|██        | 302/1500 [00:56<25:19,  1.27s/it]

step 300: loss 1.0792


Training Progress:  27%|██▋       | 401/1500 [01:12<33:59,  1.86s/it]

step 400: loss 1.0912


Training Progress:  33%|███▎      | 502/1500 [01:29<21:39,  1.30s/it]

step 500: loss 1.0584


Training Progress:  40%|████      | 601/1500 [01:45<21:13,  1.42s/it]

step 600: loss 1.0345


Training Progress:  47%|████▋     | 701/1500 [02:02<24:24,  1.83s/it]

step 700: loss 1.0933


Training Progress:  53%|█████▎    | 801/1500 [02:18<18:52,  1.62s/it]

step 800: loss 1.0429


Training Progress:  60%|██████    | 901/1500 [02:35<19:29,  1.95s/it]

step 900: loss 1.0698


Training Progress:  67%|██████▋   | 1001/1500 [02:51<15:07,  1.82s/it]

step 1000: loss 1.0362


Training Progress:  73%|███████▎  | 1101/1500 [03:07<12:05,  1.82s/it]

step 1100: loss 1.0408


Training Progress:  80%|████████  | 1201/1500 [03:24<09:07,  1.83s/it]

step 1200: loss 1.0345


Training Progress:  87%|████████▋ | 1301/1500 [03:40<06:03,  1.83s/it]

step 1300: loss 1.0309


Training Progress:  93%|█████████▎| 1401/1500 [03:57<02:59,  1.81s/it]

step 1400: loss 1.0302


Training Progress: 100%|██████████| 1500/1500 [04:14<00:00,  5.90it/s]

step 1499: loss 1.0303
Training finished for self-improve round 4.
Evaluating 15-digit accuracy...





Accuracy for 15 digits: 0.995
Accuracy for 15 digits: 0.994
Accuracy for 15 digits: 0.994
Accuracy for 15 digits: 0.997
Accuracy for 15 digits: 0.998
Accuracy for 15 digits: 0.996
Accuracy for 15 digits: 0.993
Accuracy for 15 digits: 0.996
Accuracy for 15 digits: 0.996
Accuracy for 15 digits: 0.999
Average accuracy for 15: 0.9958000000000002
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_4.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_4.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_4.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_4.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_4.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_4.pt
This is round 4, data larger has 2150000 rows
This is round 4, data smaller h

Training Progress:   0%|          | 1/1500 [00:06<2:45:09,  6.61s/it]

step 0: loss 1.1443


Training Progress:   7%|▋         | 102/1500 [00:22<27:57,  1.20s/it]

step 100: loss 1.0393


Training Progress:  13%|█▎        | 201/1500 [00:39<35:43,  1.65s/it]

step 200: loss 1.0408


Training Progress:  20%|██        | 302/1500 [00:55<26:02,  1.30s/it]

step 300: loss 1.0632


Training Progress:  27%|██▋       | 401/1500 [01:12<32:39,  1.78s/it]

step 400: loss 1.0433


Training Progress:  33%|███▎      | 501/1500 [01:28<30:09,  1.81s/it]

step 500: loss 1.0485


Training Progress:  40%|████      | 601/1500 [01:45<29:21,  1.96s/it]

step 600: loss 1.0469


Training Progress:  47%|████▋     | 701/1500 [02:01<24:09,  1.81s/it]

step 700: loss 1.0402


Training Progress:  53%|█████▎    | 802/1500 [02:18<14:40,  1.26s/it]

step 800: loss 1.0679


Training Progress:  60%|██████    | 901/1500 [02:34<18:05,  1.81s/it]

step 900: loss 1.0879


Training Progress:  67%|██████▋   | 1001/1500 [02:51<10:50,  1.30s/it]

step 1000: loss 1.0412


Training Progress:  73%|███████▎  | 1101/1500 [03:07<11:53,  1.79s/it]

step 1100: loss 1.0409


Training Progress:  80%|████████  | 1201/1500 [03:23<07:19,  1.47s/it]

step 1200: loss 1.0379


Training Progress:  87%|████████▋ | 1301/1500 [03:40<06:05,  1.84s/it]

step 1300: loss 1.0375


Training Progress:  93%|█████████▎| 1401/1500 [03:56<03:01,  1.83s/it]

step 1400: loss 1.0368


Training Progress: 100%|██████████| 1500/1500 [04:13<00:00,  5.92it/s]

step 1499: loss 1.0369
Training finished for self-improve round 5.
Evaluating 16-digit accuracy...





Accuracy for 16 digits: 0.997
Accuracy for 16 digits: 0.999
Accuracy for 16 digits: 0.999
Accuracy for 16 digits: 0.996
Accuracy for 16 digits: 0.995
Accuracy for 16 digits: 0.997
Accuracy for 16 digits: 0.996
Accuracy for 16 digits: 1.0
Accuracy for 16 digits: 0.997
Accuracy for 16 digits: 0.999
Average accuracy for 16: 0.9974999999999999
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_5.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_5.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_5.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_5.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_5.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_5.pt
This is round 5, data larger has 2200000 rows
This is round 5, data smaller has

Training Progress:   0%|          | 1/1500 [00:06<2:36:16,  6.26s/it]

step 0: loss 1.1620


Training Progress:   7%|▋         | 101/1500 [00:22<42:51,  1.84s/it]

step 100: loss 1.0473


Training Progress:  13%|█▎        | 201/1500 [00:39<42:29,  1.96s/it]

step 200: loss 1.0630


Training Progress:  20%|██        | 301/1500 [00:56<36:42,  1.84s/it]

step 300: loss 1.0866


Training Progress:  27%|██▋       | 401/1500 [01:12<30:39,  1.67s/it]

step 400: loss 1.0480


Training Progress:  33%|███▎      | 501/1500 [01:29<30:31,  1.83s/it]

step 500: loss 1.0453


Training Progress:  40%|████      | 601/1500 [01:45<26:46,  1.79s/it]

step 600: loss 1.0494


Training Progress:  47%|████▋     | 702/1500 [02:02<17:08,  1.29s/it]

step 700: loss 1.0455


Training Progress:  53%|█████▎    | 802/1500 [02:18<13:24,  1.15s/it]

step 800: loss 1.1020


Training Progress:  60%|██████    | 901/1500 [02:34<18:17,  1.83s/it]

step 900: loss 1.0516


Training Progress:  67%|██████▋   | 1001/1500 [02:51<15:25,  1.85s/it]

step 1000: loss 1.0864


Training Progress:  73%|███████▎  | 1101/1500 [03:07<11:58,  1.80s/it]

step 1100: loss 1.0578


Training Progress:  80%|████████  | 1201/1500 [03:24<09:49,  1.97s/it]

step 1200: loss 1.0474


Training Progress:  87%|████████▋ | 1301/1500 [03:41<04:45,  1.44s/it]

step 1300: loss 1.0443


Training Progress:  93%|█████████▎| 1401/1500 [03:57<03:01,  1.83s/it]

step 1400: loss 1.0437


Training Progress: 100%|██████████| 1500/1500 [04:13<00:00,  5.91it/s]

step 1499: loss 1.0435
Training finished for self-improve round 6.
Evaluating 17-digit accuracy...





Accuracy for 17 digits: 0.996
Accuracy for 17 digits: 0.997
Accuracy for 17 digits: 0.998
Accuracy for 17 digits: 0.994
Accuracy for 17 digits: 0.998
Accuracy for 17 digits: 1.0
Accuracy for 17 digits: 0.994
Accuracy for 17 digits: 0.997
Accuracy for 17 digits: 0.999
Accuracy for 17 digits: 0.994
Average accuracy for 17: 0.9966999999999999
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_6.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_6.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_6.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_6.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_6.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_6.pt
This is round 6, data larger has 2250000 rows
This is round 6, data smaller has

Training Progress:   0%|          | 1/1500 [00:06<2:31:36,  6.07s/it]

step 0: loss 1.1476


Training Progress:   7%|▋         | 101/1500 [00:23<45:55,  1.97s/it]

step 100: loss 1.0543


Training Progress:  13%|█▎        | 201/1500 [00:39<40:19,  1.86s/it]

step 200: loss 1.0522


Training Progress:  20%|██        | 301/1500 [00:56<36:53,  1.85s/it]

step 300: loss 1.0518


Training Progress:  27%|██▋       | 401/1500 [01:12<33:21,  1.82s/it]

step 400: loss 1.0526


Training Progress:  33%|███▎      | 501/1500 [01:29<29:01,  1.74s/it]

step 500: loss 1.0690


Training Progress:  40%|████      | 601/1500 [01:45<27:55,  1.86s/it]

step 600: loss 1.0521


Training Progress:  47%|████▋     | 701/1500 [02:02<23:21,  1.75s/it]

step 700: loss 1.0517


Training Progress:  53%|█████▎    | 801/1500 [02:18<15:52,  1.36s/it]

step 800: loss 1.0510


Training Progress:  60%|██████    | 901/1500 [02:35<18:16,  1.83s/it]

step 900: loss 1.0514


Training Progress:  67%|██████▋   | 1002/1500 [02:51<09:57,  1.20s/it]

step 1000: loss 1.0510


Training Progress:  73%|███████▎  | 1101/1500 [03:08<12:29,  1.88s/it]

step 1100: loss 1.0507


Training Progress:  80%|████████  | 1201/1500 [03:25<09:07,  1.83s/it]

step 1200: loss 1.0507


Training Progress:  87%|████████▋ | 1301/1500 [03:41<06:00,  1.81s/it]

step 1300: loss 1.0502


Training Progress:  93%|█████████▎| 1401/1500 [03:58<02:48,  1.70s/it]

step 1400: loss 1.0500


Training Progress: 100%|██████████| 1500/1500 [04:14<00:00,  5.90it/s]

step 1499: loss 1.0497
Training finished for self-improve round 7.
Evaluating 18-digit accuracy...





Accuracy for 18 digits: 0.998
Accuracy for 18 digits: 0.997
Accuracy for 18 digits: 0.996
Accuracy for 18 digits: 0.996
Accuracy for 18 digits: 0.999
Accuracy for 18 digits: 0.995
Accuracy for 18 digits: 0.997
Accuracy for 18 digits: 0.993
Accuracy for 18 digits: 0.997
Accuracy for 18 digits: 0.996
Average accuracy for 18: 0.9964000000000001
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_7.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_7.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_7.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_7.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_7.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_7.pt
This is round 7, data larger has 2300000 rows
This is round 7, data smaller h

Training Progress:   0%|          | 1/1500 [00:06<2:34:06,  6.17s/it]

step 0: loss 1.1563


Training Progress:   7%|▋         | 101/1500 [00:22<40:19,  1.73s/it]

step 100: loss 1.0591


Training Progress:  13%|█▎        | 201/1500 [00:39<43:02,  1.99s/it]

step 200: loss 1.0974


Training Progress:  20%|██        | 301/1500 [00:56<36:38,  1.83s/it]

step 300: loss 1.0609


Training Progress:  27%|██▋       | 401/1500 [01:12<33:41,  1.84s/it]

step 400: loss 1.0570


Training Progress:  33%|███▎      | 501/1500 [01:28<30:33,  1.84s/it]

step 500: loss 1.0620


Training Progress:  40%|████      | 601/1500 [01:45<27:28,  1.83s/it]

step 600: loss 1.1086


Training Progress:  47%|████▋     | 701/1500 [02:02<24:30,  1.84s/it]

step 700: loss 1.0599


Training Progress:  53%|█████▎    | 802/1500 [02:18<15:17,  1.31s/it]

step 800: loss 1.1034


Training Progress:  60%|██████    | 901/1500 [02:34<18:15,  1.83s/it]

step 900: loss 1.0902


Training Progress:  67%|██████▋   | 1001/1500 [02:51<15:18,  1.84s/it]

step 1000: loss 1.0696


Training Progress:  73%|███████▎  | 1101/1500 [03:08<13:13,  1.99s/it]

step 1100: loss 1.1108


Training Progress:  80%|████████  | 1201/1500 [03:24<09:11,  1.84s/it]

step 1200: loss 1.0846


Training Progress:  87%|████████▋ | 1301/1500 [03:41<06:05,  1.84s/it]

step 1300: loss 1.0574


Training Progress:  93%|█████████▎| 1401/1500 [03:57<03:02,  1.84s/it]

step 1400: loss 1.0564


Training Progress: 100%|██████████| 1500/1500 [04:13<00:00,  5.91it/s]

step 1499: loss 1.0562
Training finished for self-improve round 8.
Evaluating 19-digit accuracy...





Accuracy for 19 digits: 0.996
Accuracy for 19 digits: 0.997
Accuracy for 19 digits: 1.0
Accuracy for 19 digits: 0.998
Accuracy for 19 digits: 0.998
Accuracy for 19 digits: 0.999
Accuracy for 19 digits: 0.994
Accuracy for 19 digits: 0.998
Accuracy for 19 digits: 0.999
Accuracy for 19 digits: 0.996
Average accuracy for 19: 0.9974999999999999
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_8.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_8.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_8.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_8.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_8.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_8.pt
This is round 8, data larger has 2350000 rows
This is round 8, data smaller has

Training Progress:   0%|          | 1/1500 [00:06<2:35:20,  6.22s/it]

step 0: loss 1.1390


Training Progress:   7%|▋         | 101/1500 [00:22<40:44,  1.75s/it]

step 100: loss 1.1073


Training Progress:  13%|█▎        | 201/1500 [00:39<40:04,  1.85s/it]

step 200: loss 1.0654


Training Progress:  20%|██        | 301/1500 [00:55<28:08,  1.41s/it]

step 300: loss 1.1108


Training Progress:  27%|██▋       | 401/1500 [01:12<33:40,  1.84s/it]

step 400: loss 1.0656


Training Progress:  33%|███▎      | 502/1500 [01:29<21:58,  1.32s/it]

step 500: loss 1.1025


Training Progress:  40%|████      | 601/1500 [01:45<27:26,  1.83s/it]

step 600: loss 1.0725


Training Progress:  47%|████▋     | 701/1500 [02:02<21:27,  1.61s/it]

step 700: loss 1.1128


Training Progress:  53%|█████▎    | 801/1500 [02:18<20:13,  1.74s/it]

step 800: loss 1.0642


Training Progress:  60%|██████    | 901/1500 [02:35<18:41,  1.87s/it]

step 900: loss 1.0962


Training Progress:  67%|██████▋   | 1001/1500 [02:52<15:24,  1.85s/it]

step 1000: loss 1.0789


Training Progress:  73%|███████▎  | 1101/1500 [03:08<11:41,  1.76s/it]

step 1100: loss 1.0637


Training Progress:  80%|████████  | 1201/1500 [03:25<09:48,  1.97s/it]

step 1200: loss 1.0631


Training Progress:  87%|████████▋ | 1301/1500 [03:41<06:05,  1.84s/it]

step 1300: loss 1.0614


Training Progress:  93%|█████████▎| 1401/1500 [03:58<03:04,  1.86s/it]

step 1400: loss 1.0612


Training Progress: 100%|██████████| 1500/1500 [04:14<00:00,  5.88it/s]

step 1499: loss 1.0611
Training finished for self-improve round 9.
Evaluating 20-digit accuracy...





Accuracy for 20 digits: 1.0
Accuracy for 20 digits: 0.995
Accuracy for 20 digits: 0.996
Accuracy for 20 digits: 0.995
Accuracy for 20 digits: 0.995
Accuracy for 20 digits: 0.998
Accuracy for 20 digits: 0.991
Accuracy for 20 digits: 0.995
Accuracy for 20 digits: 0.994
Accuracy for 20 digits: 0.992
Average accuracy for 20: 0.9951000000000001
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_9.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_9.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_9.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_9.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_9.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_9.pt
This is round 9, data larger has 2400000 rows
This is round 9, data smaller has

Training Progress:   0%|          | 1/1500 [00:06<2:50:30,  6.82s/it]

step 0: loss 1.1706


Training Progress:   7%|▋         | 101/1500 [00:23<42:52,  1.84s/it]

step 100: loss 1.0704


Training Progress:  13%|█▎        | 202/1500 [00:39<28:24,  1.31s/it]

step 200: loss 1.0710


Training Progress:  20%|██        | 302/1500 [00:56<26:21,  1.32s/it]

step 300: loss 1.0667


Training Progress:  27%|██▋       | 402/1500 [01:12<23:58,  1.31s/it]

step 400: loss 1.1031


Training Progress:  33%|███▎      | 501/1500 [01:29<30:56,  1.86s/it]

step 500: loss 1.1152


Training Progress:  40%|████      | 601/1500 [01:46<27:51,  1.86s/it]

step 600: loss 1.0692


Training Progress:  47%|████▋     | 701/1500 [02:02<24:35,  1.85s/it]

step 700: loss 1.0682


Training Progress:  53%|█████▎    | 801/1500 [02:19<22:21,  1.92s/it]

step 800: loss 1.0883


Training Progress:  60%|██████    | 901/1500 [02:35<18:26,  1.85s/it]

step 900: loss 1.0685


Training Progress:  67%|██████▋   | 1001/1500 [02:52<15:19,  1.84s/it]

step 1000: loss 1.0955


Training Progress:  73%|███████▎  | 1101/1500 [03:08<12:16,  1.85s/it]

step 1100: loss 1.1171


Training Progress:  80%|████████  | 1201/1500 [03:25<09:08,  1.83s/it]

step 1200: loss 1.0917


Training Progress:  87%|████████▋ | 1301/1500 [03:41<05:52,  1.77s/it]

step 1300: loss 1.0688


Training Progress:  93%|█████████▎| 1401/1500 [03:58<03:02,  1.84s/it]

step 1400: loss 1.0674


Training Progress: 100%|██████████| 1500/1500 [04:15<00:00,  5.87it/s]

step 1499: loss 1.0671
Training finished for self-improve round 10.
Evaluating 21-digit accuracy...





Accuracy for 21 digits: 0.996
Accuracy for 21 digits: 0.992
Accuracy for 21 digits: 0.998
Accuracy for 21 digits: 0.994
Accuracy for 21 digits: 0.993
Accuracy for 21 digits: 0.996
Accuracy for 21 digits: 0.991
Accuracy for 21 digits: 0.994
Accuracy for 21 digits: 0.993
Accuracy for 21 digits: 0.996
Average accuracy for 21: 0.9943
Saved best main model at /content/drive/MyDrive/URPS/Models/sc_model_10.pt
Saved updated pretrained model 0 at /content/drive/MyDrive/URPS/Models/pretrained_model_0_round_10.pt
Saved updated pretrained model 1 at /content/drive/MyDrive/URPS/Models/pretrained_model_1_round_10.pt
Saved updated pretrained model 2 at /content/drive/MyDrive/URPS/Models/pretrained_model_2_round_10.pt
Saved updated pretrained model 3 at /content/drive/MyDrive/URPS/Models/pretrained_model_3_round_10.pt
Saved updated pretrained model 4 at /content/drive/MyDrive/URPS/Models/pretrained_model_4_round_10.pt
This is round 10, data larger has 2450000 rows
This is round 10, data smaller has 5

0,1
Accuracy,█▇▁▃▅▄▄▅▂▁
digit_step,▁▂▃▃▄▅▆▆▇█
train_loss,▃▁▂▁▁█▃▄▁▁▂▂▂▃▃▃▂▂▃▃▃▃▃▇▃▃▃▃▃▅▄▅▃▅▄▃█▄▅▄
train_step,▁▃▃▄▃▅▅▆▇▁▆▁▃▇▇▃▄▅▅█▁▃▃▅▇▅▇█▃▇▁▂▄▅▆▁▃▃▅▇

0,1
Accuracy,0.9943
digit_step,21.0
train_loss,1.06709
train_step,15.0


In [None]:
# Score every saved checkpoint (sc_model_0 .. sc_model_10) on inputs of 11 to 20 digits.
# diff_model_performance maps checkpoint index -> list of ten values returned by
# test_accuracy_on_digits, ordered by digit count (11 first, 20 last).
diff_model_performance = {}
for round_idx in range(11):
    checkpoint_path = f"/content/drive/MyDrive/URPS/Models/sc_model_{round_idx}.pt"

    # A fresh model is built for every checkpoint so no weights carry over between rounds.
    model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias)
    model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    diff_model_performance[round_idx] = [
        test_accuracy_on_digits(model, n_digits) for n_digits in range(11, 21)
    ]

Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 0.999
Accuracy for 11 digits: 1.0
Accuracy for 12 digits: 0.991
Accuracy for 12 digits: 0.988
Accuracy for 12 digits: 0.99
Accuracy for 12 digits: 0.991
Accuracy for 12 digits: 0.981
Accuracy for 12 digits: 0.984
Accuracy for 12 digits: 0.989
Accuracy for 12 digits: 0.984
Accuracy for 12 digits: 0.987
Accuracy for 12 digits: 0.984
Accuracy for 13 digits: 0.809
Accuracy for 13 digits: 0.814
Accuracy for 13 digits: 0.81
Accuracy for 13 digits: 0.809
Accuracy for 13 digits: 0.802
Accuracy for 13 digits: 0.831
Accuracy for 13 digits: 0.848
Accuracy for 13 digits: 0.806
Accuracy for 13 digits: 0.791
Accuracy for 13 digits: 0.835
Accuracy for 14 digits: 0.363
Accuracy for 14 digits: 0.388
Accuracy for 14 digits: 0.392
Accuracy for 14 digi

In [None]:
# Plot, for every self-improvement checkpoint, the accuracy on 11- to 20-digit inputs,
# then upload the interactive figure to Weights & Biases.
fig = go.Figure()

# Digit counts on the x axis; must match the range used when diff_model_performance was filled.
x_values = list(range(11, 21))

# Label each trace with its dictionary key (the checkpoint index) rather than a
# separately maintained counter, so the legend cannot drift from the data.
for si_round_idx, round_accuracies in diff_model_performance.items():
    fig.add_trace(go.Scatter(x=x_values,
                             y=round_accuracies,
                             mode='lines+markers',
                             name=f"Self-improvement round {si_round_idx}"))

# One layout call: the axis titles were previously set twice with identical values.
fig.update_layout(
    title="10 rounds of self-improvement, with majority voting",
    xaxis_title="number of digits",
    yaxis_title="Average Accuracy",
    width=1000,
    height=500,
)
fig.update_yaxes(range=[-0.02, 1.02])  # small margin so markers at 0 and 1 are not clipped
fig.update_xaxes(tickmode="array", tickvals=x_values)  # one tick per evaluated digit count

fig.show()

wandb.init(project="transformer_si_graphs", name="si for 10 rounds with majority voting")
wandb.log({"Interactive Chart": wandb.Html(fig.to_html())})
wandb.finish()

In [None]:
def save_wrong_answers(si_data_file, si_round, original=10,
                       output_dir="/content/drive/MyDrive/URPS/Data"):
    """
    Copy the incorrectly answered lines of an SI data file into a "wrong answers" file.

    Each non-blank line is read as ``<expected>=<generated>&``: the first
    ``si_round + original`` characters are the expected answer, one separator
    character is skipped, and the next ``si_round + original`` characters are the
    generated answer. A line is kept as wrong when the generated answer differs from
    the expected one, or when the generated answer keeps going (a digit or '=' follows
    it directly). The second check matters because a fixed-width slice alone would
    accept an over-long answer such as ``123=1234&`` as correct.

    NOTE(review): the over-long check assumes that any text following a complete
    answer starts with the '&' end token (or another non-digit) — confirm against
    the output format of generate().

    Parameters:
      si_data_file: Path of the SI data file to scan.
      si_round: Self-improvement round; sets the answer length and the output file name.
      original: Base answer length that si_round is added to (default 10).
      output_dir: Directory that receives ``wrong_answers_round_{si_round}.txt``.

    Returns:
      The path of the written wrong-answers file.
    """
    answer_len = si_round + original
    answer_start = answer_len + 1  # skip the separator that follows the expected answer
    answer_end = answer_start + answer_len

    wrong_answers = []
    with open(si_data_file, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines are not samples; skipping them keeps them out of the output file.
            if not line.strip():
                continue
            expected = line[:answer_len]
            generated = line[answer_start:answer_end]
            # Character right after the answer slice ('' when the line ends there).
            trailing = line[answer_end:answer_end + 1]
            ran_long = trailing.isdigit() or trailing == "="
            if generated != expected or ran_long:
                wrong_answers.append(line)

    wrong_filename = f"{output_dir}/wrong_answers_round_{si_round}.txt"
    with open(wrong_filename, "w", encoding="utf-8") as f:
        f.writelines(wrong_answers)
    print(f"Round {si_round}: Saved {len(wrong_answers)} wrong answers to {wrong_filename}")
    return wrong_filename



In [None]:
# Generate fresh SI data with each saved model (sc_model_0 through sc_model_10)
# and record the samples that model answers incorrectly.
for r in range(0, 11):
    # -----------------------------------
    # Step 1: Load the model checkpoint.
    # -----------------------------------
    model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias).to(device)
    ckpt_path = f"/content/drive/MyDrive/URPS/Models/sc_model_{r}.pt"
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print(f"Loaded model checkpoint: sc_model_{r}.pt")

    # -----------------------------------
    # Step 2: Determine evaluation parameters.
    # -----------------------------------
    # The base model (r == 0) is paired with si_round 1; every later checkpoint r uses si_round r.
    # NOTE(review): sc_model_0 and sc_model_1 are therefore both sampled with si_round 1 —
    # confirm that this pairing is intended.
    current_si_round = r if r > 0 else 1
    digit_step = 10 + current_si_round + 1  # 12 for r in (0, 1), r + 11 afterwards; not read again in this cell

    # -----------------------------------
    # Step 3: Generate a batch of SI data using the current model.
    # -----------------------------------
    # Draw 50,000 prompts for this round.
    # Presumably each prompt carries (original + current_si_round) digits — see generate_prompt_OOD.
    sample_prompts = [generate_prompt_OOD(current_si_round, 'copy', original=10) for _ in range(50000)]
    encoded_prompts = [encode(p) for p in sample_prompts]
    prompt_tensor = torch.tensor(encoded_prompts, dtype=torch.long, device=device)
    # Generate outputs in slices so a single call never sees more than batch_size_generate prompts.
    batch_size_generate = 1024  # or smaller, adjust as needed
    outputs = []
    for i in range(0, prompt_tensor.shape[0], batch_size_generate):
        batch_outputs = generate(model=model, idx=prompt_tensor[i:i + batch_size_generate], max_new_tokens=35, top_k=1)
        outputs.extend(batch_outputs)

    # -----------------------------------
    # Step 4: Save the SI data to a temporary file.
    # -----------------------------------
    # Every generated string is terminated with the '&' end token and a newline.
    temp_si_data_file = f"/content/drive/MyDrive/URPS/Data/temp_si_data_round_{r}.txt"
    with open(temp_si_data_file, "w", encoding="utf-8") as f:
        f.writelines(line + "&\n" for line in outputs)
    print(f"Saved temporary SI data for round {r} to {temp_si_data_file}")

    # -----------------------------------
    # Step 5: Extract and save wrong answers.
    # -----------------------------------
    # save_wrong_answers names its output after the si_round it is given, i.e.
    # "wrong_answers_round_{current_si_round}.txt".
    wrong_file = save_wrong_answers(temp_si_data_file, current_si_round)

    # r == 0 and r == 1 both use si_round 1, so the base model's file would be
    # overwritten on the next iteration. Move it to its own name so both survive;
    # wrong_answers_round_1.txt still ends up holding sc_model_1's wrong answers.
    if r == 0:
        base_wrong_file = "/content/drive/MyDrive/URPS/Data/wrong_answers_base_model.txt"
        os.replace(wrong_file, base_wrong_file)
        wrong_file = base_wrong_file
        print(f"Moved base-model wrong answers to {base_wrong_file}")


Loaded model checkpoint: sc_model_0.pt
Saved temporary SI data for round 0 to /content/drive/MyDrive/URPS/Data/temp_si_data_round_0.txt
Round 1: Saved 16 wrong answers to /content/drive/MyDrive/URPS/Data/wrong_answers_round_1.txt
Loaded model checkpoint: sc_model_1.pt
Saved temporary SI data for round 1 to /content/drive/MyDrive/URPS/Data/temp_si_data_round_1.txt
Round 1: Saved 8 wrong answers to /content/drive/MyDrive/URPS/Data/wrong_answers_round_1.txt
Loaded model checkpoint: sc_model_2.pt
Saved temporary SI data for round 2 to /content/drive/MyDrive/URPS/Data/temp_si_data_round_2.txt
Round 2: Saved 8 wrong answers to /content/drive/MyDrive/URPS/Data/wrong_answers_round_2.txt
Loaded model checkpoint: sc_model_3.pt
Saved temporary SI data for round 3 to /content/drive/MyDrive/URPS/Data/temp_si_data_round_3.txt
Round 3: Saved 99 wrong answers to /content/drive/MyDrive/URPS/Data/wrong_answers_round_3.txt
Loaded model checkpoint: sc_model_4.pt
Saved temporary SI data for round 4 to /con

In [None]:
def test_wrong_answers_accuracy(model, wrong_file, si_round):
    """
    Evaluate the model's performance on wrong answer samples.

    Each line in wrong_file is assumed to have the format:
         <expected>=<generated>&\n
    where <expected> is the correct answer (of length si_round+10) and
    <generated> is the model-generated answer (of length si_round+10) following an "=".

    The function constructs a prompt as <expected> + "=",
    then generates an output from the model and extracts the generated portion.
    It then compares this new generated answer with <expected> and computes
    the fraction of samples for which the model now produces the correct expected answer.

    Parameters:
      model: The GPT model used for generation.
      wrong_file: Path to the file containing wrong answer samples.
      si_round: The self-improvement round number used to determine expected lengths.

    Returns:
      accuracy: The fraction of samples that the model now corrects.
    """
    # Read all wrong-answer lines from the file.
    with open(wrong_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    total = len(lines)
    if total == 0:
        print("No wrong answer samples found.")
        return 0.0

    correct_count = 0
    # Loop through each wrong answer sample.
    for line in lines:
        # Extract the expected answer.
        expected = line[:(si_round+10)]
        # Construct the prompt for generation.
        prompt = expected + "="
        # Encode the prompt.
        prompt_ids = encode(prompt)
        prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        # Generate a new output using the model.
        new_output = generate(model=model, idx=prompt_tensor, max_new_tokens=35, top_k=1)[0]
        # Extract the generated part (assumed to be of length si_round+10 immediately after the '=' token).
        new_generated = new_output[(si_round+10+1):(si_round+10+1+si_round+10)]
        # If the new generated answer matches the expected answer, count it as corrected.
        if new_generated == expected:
            correct_count += 1

    accuracy = correct_count / total
    print(f"Evaluated {total} wrong samples; model corrected {correct_count} of them. Accuracy: {accuracy:.4f}")
    return accuracy


In [None]:
# For each round T (from 1 to 9), evaluate the wrong_answers_round_T.txt file
# using the checkpoints from round T+1 to round 10.
#
# Besides printing, the accuracies are collected in wrong_answer_eval_results as
#   {T: ([checkpoint rounds], [accuracies])}
# so they can be plotted directly instead of being copied by hand from the log.
wrong_answer_eval_results = {}
for t in range(1, 10):  # You can adjust this range as needed.
    wrong_eval_file = f"/content/drive/MyDrive/URPS/Data/wrong_answers_round_{t}.txt"
    print(f"\nEvaluating wrong answers from round {t}:")

    checkpoint_rounds = []
    checkpoint_accuracies = []
    # Loop over subsequent checkpoints
    for r in range(t+1, 11):
        model = GPT(vocab_size, block_size, n_embd, n_layer, n_head, dropout, bias).to(device)
        ckpt_path = f"/content/drive/MyDrive/URPS/Models/sc_model_{r}.pt"
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

        # Evaluate the model on the wrong answers from round t.
        acc = test_wrong_answers_accuracy(model, wrong_eval_file, t)
        print(f"  Model sc_model_{r} accuracy on wrong answers from round {t}: {acc:.4f}")

        checkpoint_rounds.append(r)
        checkpoint_accuracies.append(acc)

    wrong_answer_eval_results[t] = (checkpoint_rounds, checkpoint_accuracies)



Evaluating wrong answers from round 1:
Evaluated 8 wrong samples; model corrected 7 of them. Accuracy: 0.8750
  Model sc_model_2 accuracy on wrong answers from round 1: 0.8750
Evaluated 8 wrong samples; model corrected 4 of them. Accuracy: 0.5000
  Model sc_model_3 accuracy on wrong answers from round 1: 0.5000
Evaluated 8 wrong samples; model corrected 6 of them. Accuracy: 0.7500
  Model sc_model_4 accuracy on wrong answers from round 1: 0.7500
Evaluated 8 wrong samples; model corrected 6 of them. Accuracy: 0.7500
  Model sc_model_5 accuracy on wrong answers from round 1: 0.7500
Evaluated 8 wrong samples; model corrected 7 of them. Accuracy: 0.8750
  Model sc_model_6 accuracy on wrong answers from round 1: 0.8750
Evaluated 8 wrong samples; model corrected 6 of them. Accuracy: 0.7500
  Model sc_model_7 accuracy on wrong answers from round 1: 0.7500
Evaluated 8 wrong samples; model corrected 7 of them. Accuracy: 0.8750
  Model sc_model_8 accuracy on wrong answers from round 1: 0.8750
E

In [None]:
import plotly.graph_objects as go

# Accuracy of later checkpoints on the wrong answers of each round.
# Layout: {T: ([checkpoint rounds T+1 .. 10], [fraction corrected by each checkpoint])}.
# The numbers are hard-coded here; they appear to be transcribed from the printed
# output of the evaluation loop, so they go stale if that loop is re-run.
evaluation_data = {
    1: ([2, 3, 4, 5, 6, 7, 8, 9, 10], [0.8750, 0.5000, 0.7500, 0.7500, 0.8750, 0.7500, 0.8750, 1.0000, 0.8750]),
    2: ([3, 4, 5, 6, 7, 8, 9, 10], [0.8750, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]),
    3: ([4, 5, 6, 7, 8, 9, 10], [0.9192, 0.9697, 0.9596, 0.9798, 0.9899, 0.9798, 0.9798]),
    4: ([5, 6, 7, 8, 9, 10], [0.8060, 0.8955, 0.9104, 0.9701, 0.9851, 0.9851]),
    5: ([6, 7, 8, 9, 10], [0.8750, 0.8750, 1.0000, 0.9750, 0.9750]),
    6: ([7, 8, 9, 10], [0.6512, 0.7674, 0.8140, 0.7907]),
    7: ([8, 9, 10], [0.6750, 0.8375, 0.7000]),
    8: ([9, 10], [0.6081, 0.5541]),
    9: ([10], [0.2561])
}

# One line per source round: x = checkpoint round, y = fraction corrected.
wrong_answer_traces = [
    go.Scatter(
        x=evaluation_data[source_round][0],
        y=evaluation_data[source_round][1],
        mode='lines+markers',
        name=f"Wrong Answers from Round {source_round}",
        hovertemplate="Checkpoint: %{x}<br>Accuracy: %{y:.4f}<extra></extra>"
    )
    for source_round in evaluation_data
]

fig = go.Figure(data=wrong_answer_traces)

fig.update_layout(
    title="Model Accuracy on Wrong Answers Across Rounds",
    xaxis_title="Model Checkpoint Round",
    yaxis_title="Accuracy on Wrong Answers",
    hovermode="x unified",
    template="plotly_white"
)

fig.show()
