In [None]:
import os

# Restrict this process to one GPU; CUDA_VISIBLE_DEVICES only takes effect if it
# is set before torch initialises CUDA, so keep this ahead of the torch import.
# Override via the PROBE_CUDA_DEVICES environment variable (default: '8').
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('PROBE_CUDA_DEVICES', '8')
# The project-relative imports (scripts.*, src.*) and paths (exps/..., .mplstyle)
# used below appear to rely on this working directory. Override via PROJECT_ROOT.
os.chdir(os.environ.get('PROJECT_ROOT', '/opt/project/'))

In [None]:
from typing import List, Tuple, Mapping, Callable
import gzip
import math
import sys
import pickle
import torch
from absl import flags
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch import nn
from torch import optim
from transformers import AutoTokenizer
from bertviz import head_view, model_view

In [None]:
import scripts.train_model
from scripts.train_model import ReadOffValues, get_tokenizer
from accelerate import Accelerator
import src.utils as utils

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set()
plt.style.use('.mplstyle')
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150
plt.rcParams['figure.figsize'] = 12,8

In [None]:
FLAGS = flags.FLAGS
# Re-running this cell (or a flag already defined by an imported module) makes
# absl raise DuplicateFlagError; that case is expected and safe to ignore.
# Any other error is a real problem and must not be silently swallowed.
try:
    flags.DEFINE_string(
        "save_encodings_path", None, help="Where to load embeddings from"
    )
except flags.DuplicateFlagError:
    pass

In [None]:
# absl-style command line for the run whose stored encodings are probed below.
# NOTE(review): --expdir/--logdir point at exps/arithmetic_bugfix/... while the
# checkpoint and encodings come from exps/arith_what_is/... — confirm intended.
args = (
    "--disable_tqdm "
    "--max_generation_len 42 "
    "--gaccum 1 "
    "--model=EleutherAI/gpt-j-6B "
    "--model_type=PromptTuningPostfixLM "
    "--seed=0 "
    "--expdir=exps/arithmetic_bugfix/seed_0/PromptTuningPostfixLM/EleutherAI/gpt-j-6B/step_10/lr_0.001 "
    "--logdir=exps/arithmetic_bugfix/seed_0/PromptTuningPostfixLM/EleutherAI/gpt-j-6B/step_10/lr_0.001 "
    "--learning_rate=0.001 "
    "--dataset=ArithmethicDataset "
    "--N_per_digit=90 "
    "--n_prompt_tokens=10 "
    "--n_coder_steps=10 "
    "--resume_from_checkpoint"
    " exps/arith_what_is/seed_0/PromptTuningPostfixLM/EleutherAI/gpt-j-6B/step_10/lr_0.001/checkpoints/iter-15.pth.tar "
    "--save_encodings_path"
    " exps/arith_what_is/seed_0/PromptTuningPostfixLM/EleutherAI/gpt-j-6B/step_10/lr_0.001/checkpoints/train_reads.pickle "
)
# Snapshot the kernel's original argv once, so re-running this cell does not keep
# stacking the flag list onto an already-modified sys.argv.
if "_KERNEL_ARGV" not in globals():
    _KERNEL_ARGV = list(sys.argv)
# Keep (at most) the first 11 original entries, as before. split() with no
# separator also drops the empty token left by the trailing space above.
sys.argv = _KERNEL_ARGV[:11] + args.split()
# The second argument is absl's known_only switch; presumably it lets the
# kernel's own command-line arguments pass through without raising.
FLAGS._parse_args(sys.argv, True)


In [None]:
# Training-split probe encodings, stored as a gzip-compressed pickle.
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
with gzip.open(FLAGS.save_encodings_path, mode='rb') as encodings_file:
    train_reads = pickle.load(encodings_file)

In [None]:
# get_tokenizer yields the tokenizer together with its padding id. The id is
# also written back onto FLAGS (presumably read by scripts.train_model — confirm).
tokenizer, padding_idx = get_tokenizer(FLAGS.model)
setattr(FLAGS, "padding_idx", padding_idx)

In [None]:
def read_to_tokens(read: "ReadOffValues",
                   tokenizer,
                   n_prompt_tokens: int = 10,
                   n_coder_steps: int = 10) -> Tuple[List[str], List[str], List[torch.Tensor]]:
    """Turn one stored read into token strings and per-layer attentions for bertviz.

    Args:
        read: record exposing ``input``/``output`` token ids and ``attentions``.
        tokenizer: object providing ``convert_ids_to_tokens``.
        n_prompt_tokens: number of ``pre{i}`` placeholder labels prepended.
        n_coder_steps: number of ``post{i}`` placeholder labels appended.

    Returns:
        ``(input_tokens, output_tokens, attentions)``: the labelled input token
        strings, the output token strings, and ``read.attentions[0][k]`` for
        every k as a tensor with an extra leading axis of size 1.
    """
    # Label the extra prefix/postfix positions so they are readable in the view.
    input_tokens = (
        [f"pre{i}" for i in range(n_prompt_tokens)]
        + tokenizer.convert_ids_to_tokens(read.input)
        + [f"post{i}" for i in range(n_coder_steps)]
    )
    output_tokens = tokenizer.convert_ids_to_tokens(read.output)
    # Each entry is indexed as a 3-D NumPy array (presumably heads x seq x seq —
    # TODO confirm); bertviz expects a leading batch axis, added via None.
    attentions = [
        torch.from_numpy(layer_attention[None, :, :, :])
        for layer_attention in read.attentions[0]
    ]
    return input_tokens, output_tokens, attentions
    

In [None]:
# Visualise the stored attentions of one training example with bertviz.
# Distinct names avoid shadowing the builtin ``input`` in the notebook namespace.
sample_tokens, sample_output_tokens, sample_attentions = read_to_tokens(
    train_reads[1],
    tokenizer,
    n_prompt_tokens=FLAGS.n_prompt_tokens,
    n_coder_steps=FLAGS.n_coder_steps,
)

head_view(sample_attentions, sample_tokens)

In [None]:
class ProbingDataset(Dataset):
    """Probing examples that pair stored LM hidden states with per-digit addition targets.

    Each example is ``(hiddens, carry_ons, values, input_ids)`` built from one
    read: ``hiddens`` is ``read.hidden_states[0][layer]``, ``carry_ons`` and
    ``values`` are the per-column carry bits and result digits of the
    "Q: What is X plus Y" question in ``read.input_str`` (least-significant
    column first), and ``input_ids`` is ``read.input``.

    All constructor keyword arguments are stored as attributes and forwarded
    to :meth:`get_data` (e.g. ``layer``, ``seed``, ``split``).
    """

    # (split name, fraction of examples); the last split always takes the remainder.
    split_ratios: List = [("train", 1.0), ("dev", 0.0), ("test", 0.0)]

    def __init__(self, reads: List["ReadOffValues"], **kwargs):
        for (k, v) in kwargs.items():
            self.__setattr__(k, v)

        self.data = self.get_data(reads, **kwargs)

    def __len__(self) -> int:
        return len(self.data)

    def count_carry(self, a, b):
        """Add two decimal digit strings column by column, schoolbook style.

        Args:
            a: first operand as a string of decimal digits.
            b: second operand as a string of decimal digits (lengths may differ).

        Returns:
            ``(carry_ons, values)``, one entry per column, least-significant
            column first: the carry (0/1) produced by that column and the digit
            written in it. A final carry-out only shows up in the last
            ``carry_ons`` entry, not as an extra digit in ``values``.
        """
        carry = 0
        carry_ons = []
        values = []
        i, j = len(a), len(b)
        while i != 0 or j != 0:
            # Missing digits of the shorter operand count as 0.
            x = 0
            y = 0
            if i > 0:
                x = int(a[i - 1])
                i -= 1
            if j > 0:
                y = int(b[j - 1])
                j -= 1

            # A column sum is at most 9 + 9 + 1 = 19, so the new carry is 0 or 1.
            carry, value = divmod(x + y + carry, 10)

            carry_ons.append(carry)
            values.append(value)

        return carry_ons, values

    def extract_data(
        self, reads, seed: int = 0, layer: int = -1, split: str = "train"
    ):
        """Build one example per read, in an order shuffled by ``seed``.

        ``split`` is accepted for call compatibility but unused here. The
        caller's ``reads`` sequence is not modified.
        """
        print(f"Layer: {layer}")
        rng = np.random.default_rng(seed)
        # Shuffle a copy: shuffling ``reads`` in place would reorder the
        # caller's list (e.g. the notebook-level reads) as a side effect and
        # make repeated constructions from the same list non-deterministic.
        reads = list(reads)
        rng.shuffle(reads)
        data = []
        for read in reads:
            hiddens = read.hidden_states[0][layer]
            # Expected question text: "Q: What is <x1> plus <x2>?\n", where the
            # digits may be space-separated and x2 may carry a trailing " =".
            x1, x2 = read.input_str.replace("Q: What is ", "").replace("?\n", "").split(" plus ")
            x1 = x1.replace(" ", "")
            x2 = x2.replace(" =", "").replace(" ", "")
            carry_ons, values = self.count_carry(x1, x2)
            data.append((hiddens, carry_ons, values, read.input))
        return data

    def get_split(self, examples, split: str = "train"):
        """Cut ``examples`` into contiguous chunks per ``split_ratios`` and return the ``split`` chunk."""
        L = len(examples)
        data = {}
        index = 0
        for (i, (split_name, ratio)) in enumerate(self.split_ratios):
            length = math.floor(L * ratio)
            if i != len(self.split_ratios) - 1:
                end_index = min(index + length, L)
            else:
                # Last split absorbs any rounding remainder.
                end_index = L
            data[split_name] = examples[index:end_index]
            index = end_index

        return data[split]

    def get_data(
        self, reads, split: str = "train", seed=0, **kwargs
    ) -> List[Tuple]:
        """Extract examples, keep the requested split, then shuffle it.

        Both shuffles use seeds derived from ``seed`` via ``utils.split_seed``.
        """
        seed, new_seed = utils.split_seed(seed)
        examples = self.extract_data(reads, seed=new_seed, **kwargs)
        examples = self.get_split(examples, split=split)
        seed, new_seed = utils.split_seed(seed)
        rng = np.random.default_rng(new_seed)
        rng.shuffle(examples)
        return examples

    def __getitem__(self, index: int) -> Tuple:
        """Return the example at ``index`` with every field converted via ``torch.tensor``."""
        data = self.data[index]
        return tuple(map(torch.tensor, data))

    @staticmethod
    def get_collate(
        pad_token_id: int = -100,
        pre_token_id: int = -1,
        post_token_id: int = -2,
        n_prompt_tokens: int = 0,
        n_coder_steps: int = 0,
    ) -> Callable:
        """Build a ``collate_fn`` that pads a list of ``__getitem__`` tuples into one batch.

        The returned function produces a dict with:
            ``carry_ons`` / ``values``: long, right-padded with -100;
            ``hiddens``: float, zero-padded;
            ``inputs``: long; ``n_prompt_tokens`` copies of ``pre_token_id``,
                then the input ids, then ``n_coder_steps`` copies of
                ``post_token_id``, right-padded with ``pad_token_id``;
            ``mask``: True wherever ``inputs == pad_token_id``.
        """
        def collate(batch) -> Mapping[str, torch.Tensor]:
            carry_ons = pad_sequence(
                [example[1] for example in batch], padding_value=-100, batch_first=True
            ).long()

            values = pad_sequence(
                [example[2] for example in batch], padding_value=-100, batch_first=True
            ).long()

            hiddens = pad_sequence(
                [example[0] for example in batch], batch_first=True
            ).float()

            inputs = pad_sequence(
                [
                    torch.concat(
                        [
                            torch.full((n_prompt_tokens,), pre_token_id),
                            example[3],
                            torch.full((n_coder_steps,), post_token_id),
                        ],
                        dim=0,
                    )
                    for example in batch
                ],
                batch_first=True,
                padding_value=pad_token_id,
            ).long()

            # NOTE(review): a real input token equal to pad_token_id is masked too.
            mask = inputs == pad_token_id

            return {
                "carry_ons": carry_ons,
                "values": values,
                "hiddens": hiddens,
                "mask": mask,
                "inputs": inputs,
            }

        return collate


In [None]:
# Validation encodings sit next to the training ones. "train" is swapped for
# "val" in the file name only; replacing it across the whole path would also
# rewrite any directory whose name happens to contain "train".
_encodings_dir, _train_encodings_file = os.path.split(FLAGS.save_encodings_path)
val_encodings_path = os.path.join(
    _encodings_dir, _train_encodings_file.replace("train", "val")
)
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
with gzip.open(val_encodings_path, 'rb') as handle:
    val_reads = pickle.load(handle)

In [None]:

# Probe the last hidden layer (index -1) for both splits.
PROBE_LAYER = -1
BATCH_SIZE = 16

train_dataset = ProbingDataset(train_reads, layer=PROBE_LAYER)
val_dataset = ProbingDataset(val_reads, layer=PROBE_LAYER)

# The prefix/postfix placeholder positions are filled with the vocab ids of the
# 'pre' and 'post' tokens.
collate_fn = ProbingDataset.get_collate(
    pad_token_id=tokenizer.pad_token_id,
    pre_token_id=tokenizer.vocab["pre"],
    post_token_id=tokenizer.vocab["post"],
    n_coder_steps=FLAGS.n_coder_steps,
    n_prompt_tokens=FLAGS.n_prompt_tokens,
)

loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
class SimpleAttention(nn.Module):
    """Dot-product attention from a fixed set of query vectors over a feature sequence.

    Args:
        n_features: size of the last dim of ``features`` (and of ``hidden``).
        n_hidden: output size of the optional key/query/memory projections.
        key, query, memory: when True, pass features (key/memory) or hidden
            (query) through a learned ``nn.Linear(n_features, n_hidden)``;
            otherwise use them as-is.
    """

    def __init__(
        self,
        n_features,
        n_hidden,
        key=False,
        query=True,
        memory=False,
    ):
        super().__init__()
        self.key = key
        self.query = query
        self.memory = memory
        self.n_features = n_features
        self.n_hidden = n_hidden

        if self.key:
            self.make_key = nn.Linear(n_features, n_hidden)
        if self.query:
            self.make_query = nn.Linear(n_features, n_hidden)
        if self.memory:
            self.make_memory = nn.Linear(n_features, n_hidden)

        # The summary has the width of ``memory``: projected when memory=True,
        # raw features otherwise.
        self.n_out = n_hidden if memory else n_features

    def forward(self, hidden, features, mask=None):
        """Attend from every query row over the sequence.

        Args:
            hidden: 2-D ``(Q, d)`` query vectors, one row per query.
            features: 3-D ``(B, T, d)`` sequence to attend over.
            mask: optional ``(B, T)``; truthy positions receive ~zero weight.

        Returns:
            ``(summary, distribution)`` with shapes ``(B, Q, n_out)`` and
            ``(B, Q, T)``; ``distribution`` sums to 1 over T.
        """
        key = self.make_key(features) if self.key else features
        memory = self.make_memory(features) if self.memory else features
        query = self.make_query(hidden) if self.query else hidden

        # scores[b, q, t] = <key[b, t], query[q]>, computed without building the
        # (B, Q, T, H) broadcast product in memory.
        scores = torch.einsum("bth,qh->bqt", key, query)

        if mask is not None:
            # A large negative bias drives masked positions to ~0 probability.
            scores = scores + mask.unsqueeze(1) * -99999

        distribution = F.softmax(scores, dim=2)
        # summary[b, q] = sum_t distribution[b, q, t] * memory[b, t]
        summary = torch.einsum("bqt,bth->bqh", distribution, memory)

        return summary, distribution


class ProbModel(nn.Module):
    """Probe predicting, per digit position, a carry class and a value class from hidden states.

    Each of ``max_digits`` positions has a learned selector vector that attends
    over the (masked) hidden-state sequence; the attended summary is projected
    to 10 logits. Carries and values have separate selectors, attention
    modules and projectors.

    Args:
        hidden_dim: width of ``data["hiddens"]``'s last dimension.
        max_digits: number of digit positions predicted.
        dropout: stored as ``self.dropout``; NOTE(review): not applied in forward.
        ignore_index: label value excluded from the loss (padding).
        w_init: scale of the random selector initialisation.
    """

    def __init__(
        self,
        hidden_dim: int,
        max_digits: int = 8,
        dropout: float = 0.0,
        ignore_index: int = -100,
        w_init: float = 0.01,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.max_digits = max_digits
        self.hidden_dim = hidden_dim

        self.carry_on_selector = nn.Parameter(
            torch.randn(max_digits, hidden_dim) * w_init
        )
        self.carry_on_projector = nn.Linear(hidden_dim // 4, 10)

        self.carry_on_attention = SimpleAttention(
            n_features=hidden_dim,
            n_hidden=hidden_dim // 4,
            key=True,
            query=True,
            memory=True,
        )

        self.value_selector = nn.Parameter(
            torch.randn(max_digits, hidden_dim) * w_init
        )
        self.value_projector = nn.Linear(hidden_dim // 4, 10)

        self.value_attention = SimpleAttention(
            n_features=hidden_dim,
            n_hidden=hidden_dim // 4,
            key=True,
            query=True,
            memory=True,
        )
        # Summed (not averaged) loss; callers normalise by the labelled-token count.
        self.loss = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')

    def forward(self, data, only_loss=False):
        """Run the probe on a collated batch.

        Args:
            data: dict with ``hiddens`` and ``mask``, plus ``carry_ons`` /
                ``values`` label tensors when ``only_loss`` is True.
            only_loss: return the summed cross-entropy of both heads instead of
                the predictions.

        Returns:
            The loss tensor, or ``((carry_logits, carry_attention),
            (value_logits, value_attention))`` with logits shaped
            ``(batch, 10, max_digits)``.
        """
        hiddens = data["hiddens"]
        carry, carry_distribution = self.carry_on_attention(
            self.carry_on_selector, hiddens, mask=data["mask"]
        )
        # Values go through their own attention module, not the carry one.
        value, value_distribution = self.value_attention(
            self.value_selector, hiddens, mask=data["mask"]
        )
        # (batch, max_digits, 10) -> (batch, 10, max_digits): CrossEntropyLoss
        # expects the class dimension at position 1.
        carry = self.carry_on_projector(carry).transpose(2, 1)
        value = self.value_projector(value).transpose(2, 1)

        if only_loss:
            # Trim predictions to the number of label positions in this batch.
            carry = carry[..., :data["carry_ons"].shape[-1]]
            value = value[..., :data["values"].shape[-1]]
            return self.loss(carry, data["carry_ons"]) + self.loss(value, data["values"])
        return (carry, carry_distribution), (value, value_distribution)


In [None]:
# Probe hyper-parameters. PROBE_HIDDEN_DIM must equal the width of the stored
# hidden states (presumably GPT-J's 4096 — TODO confirm / derive from the data).
PROBE_HIDDEN_DIM = 4096
PROBE_MAX_DIGITS = 8
PROBE_LR = 1e-4

model = ProbModel(hidden_dim=PROBE_HIDDEN_DIM, max_digits=PROBE_MAX_DIGITS)
optimizer = optim.Adam(model.parameters(), lr=PROBE_LR)
accelerator = Accelerator()
device = accelerator.device

In [None]:
# Hand the model, optimizer and both loaders to accelerate for device placement/wrapping.
(model, optimizer, train_loader, val_loader) = accelerator.prepare(
    model, optimizer, train_loader, val_loader
)

In [None]:
def train_loop(model, optimizer, accelerator, train_loader, epochs=20):
    """Train ``model`` for ``epochs`` passes and print each epoch's mean per-token loss.

    ``model(data, only_loss=True)`` is expected to return the loss *summed*
    over labelled positions (labels equal to -100 are padding). It is divided
    by the number of labelled positions, so each step optimises the mean
    per-token loss.

    Args:
        model: module called as ``model(data, only_loss=True)``.
        optimizer: optimizer over ``model``'s parameters.
        accelerator: object providing ``backward`` and ``clip_grad_norm_``.
        train_loader: iterable of batch dicts with ``carry_ons``/``values`` labels.
        epochs: number of passes over ``train_loader``.
    """
    optimizer.zero_grad()
    model.train()
    for epoch in range(epochs):
        # Reset every epoch so the printed number is that epoch's average,
        # not a running average since the start of training.
        total_loss = 0.0
        total_count = 0
        for data in train_loader:
            token_count = int(
                (data["carry_ons"] != -100).sum().item()
                + (data["values"] != -100).sum().item()
            )
            if token_count == 0:
                # No labels: a 0/0 loss would be NaN and poison the weights.
                continue

            loss = model(data, only_loss=True) / token_count

            accelerator.backward(loss)
            # accelerate's clipping also unscales gradients under mixed precision.
            accelerator.clip_grad_norm_(model.parameters(), 5.0)

            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item() * token_count
            total_count += token_count

        avg_loss = total_loss / total_count if total_count else float("nan")
        print(f"train/loss/{epoch}: {avg_loss}")


In [None]:
def eval_loop(model, val_loader, iter=0):
    """Measure per-position accuracy of the carry and value heads on ``val_loader``.

    Carry and value positions are pooled into one accuracy. Padding labels
    (-100) are excluded from the denominator and can never match a prediction.

    Args:
        model: module returning ``((carry_logits, _), (value_logits, _))`` with
            classes on dim 1.
        val_loader: iterable of batch dicts with ``carry_ons``/``values`` labels.
        iter: tag used in the printed metric name.

    Returns:
        The accuracy as a float (NaN when there are no labelled positions).
    """
    total_corrects = 0
    total_count = 0
    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data in val_loader:
            (carry, _), (value, _) = model(data, only_loss=False)
            carry_labels = data["carry_ons"]
            value_labels = data["values"]

            # argmax over the class dim, then trim to the number of label positions.
            carry_preds = carry.argmax(dim=1)[..., :carry_labels.shape[-1]]
            value_preds = value.argmax(dim=1)[..., :value_labels.shape[-1]]

            total_count += int(
                (value_labels != -100).sum().item()
                + (carry_labels != -100).sum().item()
            )
            # Predictions are >= 0, so padded (-100) labels never count as correct.
            total_corrects += int(
                (value_labels == value_preds).sum().item()
                + (carry_labels == carry_preds).sum().item()
            )

    avg_accuracy = total_corrects / total_count if total_count else float("nan")
    print(f"val/accuracy/{iter}: {avg_accuracy}")
    return avg_accuracy

In [None]:
# Fit the probe for 50 epochs over the prepared training loader.
train_loop(model, optimizer, accelerator, train_loader=train_loader, epochs=50)

In [None]:
# Score the trained probe on the validation split.
eval_loop(model, val_loader=val_loader, iter=0)

In [None]:
# Batch size 1, so every element of this loader is a single validation example.
val_single_loader = accelerator.prepare(
    DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
)

In [None]:
# Materialise every single-example batch so individual examples can be indexed.
datas = list(val_single_loader)

In [None]:
# Inspect the probe's attention on one validation example. Names avoid
# shadowing the builtins ``id`` and ``input`` in the notebook namespace.
example_idx = 15
data = datas[example_idx]
model.eval()
with torch.no_grad():
    (carry, carry_distribution), (value, value_distribution) = model(
        data, only_loss=False
    )

print(carry_distribution.shape)
# Drop the batch axis (batch_size is 1 here) and move to NumPy for plotting.
carry_distribution = carry_distribution[0, ...].cpu().detach().numpy()
value_distribution = value_distribution[0, ...].cpu().detach().numpy()
input_ids = data['inputs'][0, ...]
# Decode token by token so each attention column gets its own label.
input_str = [tokenizer.decode(t) for t in input_ids]

In [None]:
def heatmap(data, row_labels, col_labels, title="", ax=None, **kwargs):
    """Draw ``data`` as a seaborn heatmap with one tick label per row and column.

    Args:
        data: 2-D array (rows x columns).
        row_labels: one label per row of ``data``.
        col_labels: one label per column of ``data``.
        title: axes title.
        ax: matplotlib Axes to draw into; ``None`` lets seaborn pick the current axes.
        **kwargs: forwarded to ``sns.heatmap``.

    Returns:
        The Axes holding the heatmap.
    """
    ax = sns.heatmap(data, ax=ax, **kwargs)

    # seaborn draws cell (i, j) over [j, j+1] x [i, i+1]; ticks go at the cell
    # centres so each label lines up with its own row/column.
    ax.set_xticks(np.arange(data.shape[1]) + 0.5)
    ax.set_yticks(np.arange(data.shape[0]) + 0.5)
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Show the column labels on top instead of at the bottom.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the column labels and right-align them.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right")
    ax.set_title(title)
    return ax

In [None]:
# Rows: digit positions 0..7; columns: input tokens.
heatmap(
    carry_distribution,
    row_labels=list(map(str, range(8))),
    col_labels=input_str,
    title="carry prob attentions",
);

In [None]:
# Rows: digit positions 0..7; columns: input tokens.
heatmap(
    value_distribution,
    row_labels=list(map(str, range(8))),
    col_labels=input_str,
    title="value prob attentions",
);