<a href="https://colab.research.google.com/github/burtsev/CoopEvo/blob/master/notebooks/RMT_GPT_Multi_Stream.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets transformers
! pip install wandb
! git clone https://github.com/burtsev/RMT-experiments
#%cd RMT-experiments
#! ls

In [None]:
import numpy as np
import os
import sys
import tqdm
import torch
import datasets
import json
import wandb
from matplotlib import pyplot as plt
from transformers import AutoTokenizer, AutoConfig
from itertools import chain
from torch.utils.data import DataLoader#, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Make the cloned repo's packages (base_models, modeling_rmt) importable.
# Guarded so that re-running the cell does not append duplicate entries.
if 'RMT-experiments' not in sys.path:
    sys.path.append('RMT-experiments')

# Never hardcode API keys in a notebook. Without an explicit key, wandb.login()
# uses the WANDB_API_KEY environment variable / stored credentials, or prompts.
# NOTE(review): the key that was previously hard-coded here has been published and should be revoked.
wandb.login()

### Load model

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
from base_models.modeling_gpt_neox_multi_str import GPTNeoXForCausalLM

# Tokenizer checkpoint on the HF hub; the model architecture comes from a local JSON config.
model_name = 'EleutherAI/pythia-70m-deduped'
config_name = 'neox_6l4hd1024'
# Resolve the config relative to the cloned repo (the same directory added to sys.path)
# instead of a Colab-only absolute '/content/...' path.
config_path = os.path.join('RMT-experiments', 'base_models', 'configs', 'gptconfigs', config_name + '.json')
with open(config_path, 'r') as file:
    # Raw config dict; later merged with the run hyper-parameters and logged to wandb.
    wb_cfg = json.load(file)

In [None]:
# Build the model from the local config (constructed directly from the config rather
# than loaded with from_pretrained); the tokenizer comes from the hub checkpoint.
model_cfg = AutoConfig.from_pretrained(config_path)
model = GPTNeoXForCausalLM(config=model_cfg)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
input_size = 512   # per-segment budget; presumably includes the memory tokens (see subtraction below)
memory_size = 0    # number of memory tokens (0 = baseline without memory)
n_segments = 1
batch_size = 8

# The segment budget is reduced by 2 * memory_size, leaving block_size text tokens per segment.
block_size = input_size - 2 * memory_size
if block_size <= 0:
    raise ValueError(f'input_size={input_size} leaves no room for text next to 2 * memory_size={2 * memory_size} memory tokens')
# Each sample is prefixed with the tokens of the (n_segments - 1) preceding blocks.
history_size = (n_segments - 1) * block_size

### Prepare dataset

In [None]:
def group_texts(examples, block_size, history_size=None):
    """Concatenate tokenized texts and re-split them into fixed-size chunks.

    Args:
        examples: batched mapping of column name -> list of token lists (as passed
            by a batched ``datasets.map``); must contain ``input_ids``.
        block_size: number of new tokens that start each chunk; must be positive.
        history_size: if given, each chunk is additionally prefixed with up to
            ``history_size`` preceding tokens (fewer at the start of the stream).

    Returns:
        dict with the same keys holding the chunks, plus ``labels`` (a shallow copy
        of the ``input_ids`` chunk list). The final chunk may be shorter than
        ``block_size``; it is kept rather than dropped.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        # range() with step 0 raises an opaque error, and a negative step silently yields no chunks.
        raise ValueError(f'block_size must be positive, got {block_size}')

    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    starts = range(0, total_length, block_size)

    if history_size is None:
        result = {
            k: [t[i : i + block_size] for i in starts]
            for k, t in concatenated_examples.items()
        }
    else:
        result = {
            k: [t[max(0, i - history_size) : i + block_size] for i in starts]
            for k, t in concatenated_examples.items()
        }
    result["labels"] = result["input_ids"].copy()
    return result

# Padding id for input_ids: the tokenizer's pad token, or EOS when no pad token is set.
id_pad_value = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
def collate_fn(batch, pad_value=None, segment_len=None):
    """Left-pad a list of tokenized samples into a batch of tensors.

    Each sequence is reversed, right-padded with ``pad_sequence`` and flipped back,
    which produces left padding (real tokens end up aligned at the right edge).

    Args:
        batch: list of dicts with ``input_ids``, ``labels`` and ``attention_mask`` lists.
        pad_value: id used to pad ``input_ids``; defaults to the global ``id_pad_value``.
        segment_len: number of trailing positions flagged in ``labels_mask`` when the
            padded width differs from it; defaults to the global ``block_size``.

    Returns:
        dict of ``(batch, max_len)`` tensors. ``labels`` is padded with -100 and
        ``attention_mask`` with 0. When ``max_len != segment_len`` a boolean
        ``labels_mask`` is added that is True only on the last ``segment_len`` positions.
    """
    # Defaults are resolved at call time, so DataLoader's collate_fn(batch) call still works.
    if pad_value is None:
        pad_value = id_pad_value
    if segment_len is None:
        segment_len = block_size

    def left_pad(key, value):
        # Reverse -> right-pad -> flip back == left padding.
        reversed_seqs = [torch.tensor(sample[key][::-1]) for sample in batch]
        return pad_sequence(reversed_seqs, batch_first=True, padding_value=value).flip(1)

    input_ids = left_pad('input_ids', pad_value)
    collated = {'input_ids': input_ids,
                'labels': left_pad('labels', -100),
                'attention_mask': left_pad('attention_mask', 0)}

    if input_ids.shape[1] != segment_len:
        # Wider than one segment (history prepended): only the last segment_len positions carry loss.
        labels_mask = torch.ones_like(input_ids, dtype=bool)
        labels_mask[:, :-segment_len] = False
        collated['labels_mask'] = labels_mask

    return collated

In [None]:
# Corpus for the baseline run ('wikitext-2-v1' is a much smaller alternative).
# NOTE(review): the memory run further down loads 'wikitext-2-v1', so the two loss
# curves come from different corpora — confirm that this is intended.
task_name = 'wikitext-103-v1'
raw_datasets = datasets.load_dataset('wikitext', task_name)
column_names = raw_datasets["train"].column_names
text_column_name = column_names[0] if "text" not in column_names else "text"


def tokenize_function(examples):
    """Tokenize the text column of one batched slice of the dataset."""
    texts = examples[text_column_name]
    return tokenizer(texts)


tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=True, remove_columns=column_names, desc="Running tokenizer on dataset"
)

# Re-chunk each split's token stream into blocks (with history prefixes when history_size > 0).
train_dataset = tokenized_datasets["train"].map(
    lambda x: group_texts(x, block_size, history_size),
    batched=True,
    desc=f"Grouping train in chunks of {block_size} and history {history_size}",
)
valid_dataset = tokenized_datasets["validation"].map(
    lambda x: group_texts(x, block_size, history_size),
    batched=True,
    desc=f"Grouping valid in chunks of {block_size}",
)

In [None]:
# Dedicated, seeded generator so the training shuffle order is reproducible.
train_rnd_generator = torch.Generator()
train_rnd_generator.manual_seed(42)
# Pinned host memory only helps host->GPU copies; PyTorch warns when it is requested without a GPU.
use_pin_memory = torch.cuda.is_available()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn,
                              shuffle=True, drop_last=False, generator=train_rnd_generator,
                              pin_memory=use_pin_memory)
# drop_last=True keeps every validation batch full; the final partial batch is skipped.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn,
                              shuffle=False, drop_last=True, pin_memory=use_pin_memory)

In [None]:
# Create an iterator from the DataLoader
gen = iter(train_dataloader)

# Define the batch number you want to print
target_batch_number = 111
current_batch_number = 0

# Iterate over the DataLoader
for batch in gen:
    if 'labels_mask' in batch:
        batch.pop('labels_mask')
    if current_batch_number == target_batch_number:
        # Move the batch to the device (e.g., CPU or GPU)
        for k, v in batch.items():
            batch[k] = v.to(device)

        # Print the content of the specific batch
        print(f"Content of Batch {current_batch_number}:")
        for key, value in batch.items():
            print(f"\n{key}:")
            if key == 'labels' and isinstance(value, torch.Tensor):
                # Decode each sequence in the tensor
                for i, seq in enumerate(value):
                    decoded_seq = tokenizer.decode(seq, skip_special_tokens=True)
                    print(f"Decoded Sequence {i} in {key}: {decoded_seq}")
            elif isinstance(value, torch.Tensor):
                print(f"{key} Tensor: {value}")
            else:
                print("[Not a tensor]", value)

        # Check and print 'labels_mask' if it exists
        if 'labels_mask' in batch:
            print("\nlabels_mask:")
            print(batch['labels_mask'])
        else:
          print('No labels_mask')

        # Stop after finding and printing the desired batch
        break

    # Increment the batch number
    current_batch_number += 1


### Fetch a sample batch for a smoke test

In [None]:
# Grab the first training batch for a smoke-test forward pass; labels_mask is dropped if present.
gen = iter(train_dataloader)
batch = {name: tensor.to(device) for name, tensor in next(gen).items() if name != 'labels_mask'}
batch['input_ids'].shape

### Add RMT

In [None]:
from modeling_rmt.language_modeling import MemoryCell, RecurrentWrapper

# Wrap the base LM: a MemoryCell with `memory_size` memory tokens, driven by a
# RecurrentWrapper over segments of `block_size` tokens (at most `n_segments`).
cell = MemoryCell(model, num_mem_tokens=memory_size)
model = RecurrentWrapper(cell, segment_size=block_size, max_n_segments=n_segments)
model.to(device)

In [None]:
# Smoke test: a single forward pass on the sample batch. No gradients are needed,
# so skip building the autograd graph.
try:
    with torch.no_grad():
        out = model(**batch)
    print('Success!')
except IndexError as err:
    # Presumably raised when the input exceeds a model size limit; show the actual cause too.
    print(f'Error: Input size too large! ({err})')

### Train the model

In [None]:
from torch.optim import AdamW

learning_rate = 1e-04
optim = AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Number of optimisation steps; the training loss is recorded every `eval_steps` steps.
train_steps, eval_steps = 20000, 50

train_gen = iter(train_dataloader)
valid_gen = iter(valid_dataloader)

In [None]:
# Hyper-parameters of this run, merged into the architecture config that is logged to wandb.
run_cfg = dict(
    input_size=block_size,
    memory_size=memory_size,
    n_segments=n_segments,
    batch_size=batch_size,
    model_name=model_name,
    config_name=config_name,
)
run_cfg['learning rate'] = learning_rate
wb_cfg.update(run_cfg)

run_name = f'mem{memory_size}_inlen{block_size}_seg{n_segments}_Multi_Str_S0S2D2_{config_name}_{task_name}'
run = wandb.init(project="RMT GPT", name=run_name, config=wb_cfg)

In [None]:
import tqdm.notebook  # a plain `import tqdm` is not guaranteed to load the notebook submodule

# Training loop: one optimizer step per batch, restarting the DataLoader when it is exhausted.
# The training loss is recorded and logged every `eval_steps` steps (no evaluation runs here).
losses = []
model.train()  # in case an earlier cell switched the model to eval mode
progress_bar = tqdm.notebook.tqdm(range(train_steps), desc='Training Progress')
train_iterator = iter(train_dataloader)

for step in progress_bar:
    optim.zero_grad()

    try:
        batch = next(train_iterator)
    except StopIteration:
        # End of the dataset: start a new pass over the DataLoader.
        train_iterator = iter(train_dataloader)
        batch = next(train_iterator)

    batch = {k: v.to(device) for k, v in batch.items()}

    out = model(**batch)
    loss = out.loss
    loss.backward()
    optim.step()

    if step % eval_steps == 0:
        current_loss = loss.item()  # one device->host sync, reused below
        losses.append(current_loss)
        progress_bar.set_description(f"Step {step}/{train_steps} - Loss: {current_loss:.4f}")
        wandb.log({'step': step, 'loss': current_loss})

wandb.finish()

In [None]:
# One loss value was recorded every `eval_steps` training steps (starting at step 0),
# so plot against the actual step number rather than the list index.
logged_steps = [i * eval_steps for i in range(len(losses))]
fig, ax = plt.subplots()
ax.plot(logged_steps, losses, label='Baseline 0 mem')
ax.set_xlabel('step')
ax.set_ylabel('train loss')
ax.legend()
plt.show()

In [None]:
print(memory_size)
# Keep the baseline curve for comparison; the memory run below binds `losses` to a new list.
loss_base = losses

# Memory run (same setup, with memory tokens)

In [None]:
# Fresh base model for the memory run, so it does not start from the baseline's trained weights.
model_cfg = AutoConfig.from_pretrained(config_path)
model = GPTNeoXForCausalLM(config=model_cfg)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Memory run. n_segments and batch_size are reused from the baseline configuration above.
input_size = 4
memory_size = 2

block_size = input_size - 2 * memory_size
# NOTE(review): input_size=4 with memory_size=2 gives block_size=0, which cannot be chunked
# (group_texts would hit range() with step 0). Fail here with a clear message instead;
# pick input_size > 2 * memory_size (e.g. 512 to match the baseline) — TODO confirm intent.
if block_size <= 0:
    raise ValueError(f'input_size={input_size} leaves no room for text next to 2 * memory_size={2 * memory_size} memory tokens')
history_size = (n_segments - 1) * block_size

### Prepare dataset

In [None]:
def group_texts(examples, block_size, history_size=None):
    """Concatenate tokenized texts and re-split them into fixed-size chunks.

    Args:
        examples: batched mapping of column name -> list of token lists (as passed
            by a batched ``datasets.map``); must contain ``input_ids``.
        block_size: number of new tokens that start each chunk; must be positive.
        history_size: if given, each chunk is additionally prefixed with up to
            ``history_size`` preceding tokens (fewer at the start of the stream).

    Returns:
        dict with the same keys holding the chunks, plus ``labels`` (a shallow copy
        of the ``input_ids`` chunk list). The final chunk may be shorter than
        ``block_size``; it is kept rather than dropped.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        # range() with step 0 raises an opaque error, and a negative step silently yields no chunks.
        raise ValueError(f'block_size must be positive, got {block_size}')

    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    starts = range(0, total_length, block_size)

    if history_size is None:
        result = {
            k: [t[i : i + block_size] for i in starts]
            for k, t in concatenated_examples.items()
        }
    else:
        result = {
            k: [t[max(0, i - history_size) : i + block_size] for i in starts]
            for k, t in concatenated_examples.items()
        }
    result["labels"] = result["input_ids"].copy()
    return result

# Padding id for input_ids: the tokenizer's pad token, or EOS when no pad token is set.
id_pad_value = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
def collate_fn(batch, pad_value=None, segment_len=None):
    """Left-pad a list of tokenized samples into a batch of tensors.

    Each sequence is reversed, right-padded with ``pad_sequence`` and flipped back,
    which produces left padding (real tokens end up aligned at the right edge).

    Args:
        batch: list of dicts with ``input_ids``, ``labels`` and ``attention_mask`` lists.
        pad_value: id used to pad ``input_ids``; defaults to the global ``id_pad_value``.
        segment_len: number of trailing positions flagged in ``labels_mask`` when the
            padded width differs from it; defaults to the global ``block_size``.

    Returns:
        dict of ``(batch, max_len)`` tensors. ``labels`` is padded with -100 and
        ``attention_mask`` with 0. When ``max_len != segment_len`` a boolean
        ``labels_mask`` is added that is True only on the last ``segment_len`` positions.
    """
    # Defaults are resolved at call time, so DataLoader's collate_fn(batch) call still works.
    if pad_value is None:
        pad_value = id_pad_value
    if segment_len is None:
        segment_len = block_size

    def left_pad(key, value):
        # Reverse -> right-pad -> flip back == left padding.
        reversed_seqs = [torch.tensor(sample[key][::-1]) for sample in batch]
        return pad_sequence(reversed_seqs, batch_first=True, padding_value=value).flip(1)

    input_ids = left_pad('input_ids', pad_value)
    collated = {'input_ids': input_ids,
                'labels': left_pad('labels', -100),
                'attention_mask': left_pad('attention_mask', 0)}

    if input_ids.shape[1] != segment_len:
        # Wider than one segment (history prepended): only the last segment_len positions carry loss.
        labels_mask = torch.ones_like(input_ids, dtype=bool)
        labels_mask[:, :-segment_len] = False
        collated['labels_mask'] = labels_mask

    return collated

In [None]:
# Corpus for the memory run.
# NOTE(review): the baseline above was trained on 'wikitext-103-v1', so the two loss
# curves come from different corpora — confirm that this is intended.
task_name = 'wikitext-2-v1'
raw_datasets = datasets.load_dataset('wikitext', task_name)
column_names = raw_datasets["train"].column_names
text_column_name = column_names[0] if "text" not in column_names else "text"


def tokenize_function(examples):
    """Tokenize the text column of one batched slice of the dataset."""
    texts = examples[text_column_name]
    return tokenizer(texts)


tokenized_datasets = raw_datasets.map(
    tokenize_function, batched=True, remove_columns=column_names, desc="Running tokenizer on dataset"
)

# Re-chunk each split's token stream into blocks (with history prefixes when history_size > 0).
train_dataset = tokenized_datasets["train"].map(
    lambda x: group_texts(x, block_size, history_size),
    batched=True,
    desc=f"Grouping train in chunks of {block_size} and history {history_size}",
)
valid_dataset = tokenized_datasets["validation"].map(
    lambda x: group_texts(x, block_size, history_size),
    batched=True,
    desc=f"Grouping valid in chunks of {block_size}",
)

In [None]:
# Dedicated, seeded generator so the training shuffle order is reproducible.
train_rnd_generator = torch.Generator()
train_rnd_generator.manual_seed(42)
# Pinned host memory only helps host->GPU copies; PyTorch warns when it is requested without a GPU.
use_pin_memory = torch.cuda.is_available()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn,
                              shuffle=True, drop_last=False, generator=train_rnd_generator,
                              pin_memory=use_pin_memory)
# drop_last=True keeps every validation batch full; the final partial batch is skipped.
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn,
                              shuffle=False, drop_last=True, pin_memory=use_pin_memory)

### Fetch a sample batch for a smoke test

In [None]:
# Grab the first training batch for a smoke-test forward pass.
gen = iter(train_dataloader)
batch = next(gen)
# collate_fn adds labels_mask only to some batches; an unconditional pop raises KeyError otherwise.
batch.pop('labels_mask', None)
for k, v in batch.items():
    batch[k] = v.to(device)
batch['input_ids'].shape

### Add RMT

In [None]:
from modeling_rmt.language_modeling import MemoryCell, RecurrentWrapper

# Wrap the fresh base LM: a MemoryCell with `memory_size` memory tokens, driven by a
# RecurrentWrapper over segments of `block_size` tokens (at most `n_segments`).
cell = MemoryCell(model, num_mem_tokens=memory_size)
model = RecurrentWrapper(cell, segment_size=block_size, max_n_segments=n_segments)
model.to(device)

In [None]:
# Smoke test: a single forward pass on the sample batch. No gradients are needed,
# so skip building the autograd graph.
try:
    with torch.no_grad():
        out = model(**batch)
    print('Success!')
except IndexError as err:
    # Presumably raised when the input exceeds a model size limit; show the actual cause too.
    print(f'Error: Input size too large! ({err})')

In [None]:
from torch.optim import AdamW

# Keep the rate in a named variable (as the baseline run does) so the value actually used is explicit.
learning_rate = 1e-03
optim = AdamW(params=model.parameters(), lr=learning_rate)

In [None]:
# train_steps / eval_steps are reused from the baseline run, so both loss curves share a schedule.
train_gen = iter(train_dataloader)
valid_gen = iter(valid_dataloader)

In [None]:
run_cfg = {
    'input_size': block_size,
    'memory_size': memory_size,
    'n_segments': n_segments,
    'batch_size': batch_size,
    'model_name': model_name,
    'config_name': config_name,
    # wb_cfg is shared with (and was already updated by) the baseline run. Record the rate
    # this run's optimizer actually uses instead of inheriting the baseline's value.
    'learning rate': optim.param_groups[0]['lr'],
}
wb_cfg.update(run_cfg)

run = wandb.init(
    project="RMT GPT",
    name='mem'+str(memory_size)+'_inlen'+str(input_size)+'_seg'+str(n_segments)+'_MultiStrS1W2'+config_name,
    config=wb_cfg
)

In [None]:
import tqdm.notebook  # a plain `import tqdm` is not guaranteed to load the notebook submodule

# Training loop: one optimizer step per batch, restarting the DataLoader when it is exhausted.
# The training loss is recorded and logged every `eval_steps` steps (no evaluation runs here).
losses = []
model.train()  # in case an earlier cell switched the model to eval mode
progress_bar = tqdm.notebook.tqdm(range(train_steps), desc='Training Progress')
train_iterator = iter(train_dataloader)

for step in progress_bar:
    optim.zero_grad()

    try:
        batch = next(train_iterator)
    except StopIteration:
        # End of the dataset: start a new pass over the DataLoader.
        train_iterator = iter(train_dataloader)
        batch = next(train_iterator)

    batch = {k: v.to(device) for k, v in batch.items()}

    out = model(**batch)
    loss = out.loss
    loss.backward()
    optim.step()

    if step % eval_steps == 0:
        current_loss = loss.item()  # one device->host sync, reused below
        losses.append(current_loss)
        progress_bar.set_description(f"Step {step}/{train_steps} - Loss: {current_loss:.4f}")
        wandb.log({'step': step, 'loss': current_loss})

wandb.finish()

In [None]:
# Each curve holds one value per `eval_steps` training steps (starting at step 0);
# plot against the actual step number rather than the list index.
# Assumes both runs used the same eval_steps (true unless it was changed in between).
fig, ax = plt.subplots()
for curve, label in ((loss_base, 'Mem 0'), (losses, 'Mem ' + str(memory_size))):
    ax.plot([i * eval_steps for i in range(len(curve))], curve, label=label, alpha=0.5)
ax.set_xlabel('step')
ax.set_ylabel('train loss')
ax.legend()
plt.show()

In [None]:
# Echo the memory size of the run that was just plotted.
print(f"{memory_size}")

In [None]:
import tqdm.notebook  # a plain `import tqdm` is not guaranteed to load the notebook submodule

# Evaluate on up to `eval_steps` validation batches, continuing from `valid_gen`.
valid_losses = []
model.eval()
with torch.no_grad():
    for step in tqdm.notebook.tqdm(range(eval_steps)):
        try:
            batch = next(valid_gen)
        except StopIteration:
            # Validation set exhausted: average over the batches evaluated so far.
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        valid_loss = out.loss
        # Record the validation loss (previously the last *training* `loss` was appended).
        valid_losses.append(valid_loss.item())

In [None]:
# Count the samples actually evaluated: the validation loop can stop early if its iterator
# runs out. Validation batches are always full (drop_last=True), so len * batch_size is exact.
print(f'Loss on {len(valid_losses) * batch_size} validation samples: {np.mean(valid_losses)}')