In [None]:
from google.colab import drive
import os
import numpy as np
import sys
import pickle

PROJECT_DIR = "/content/drive/MyDrive/Ethz/CSNLP/csnlp_project/csnlp-dataset-distillation-main"

# Mount Google Drive, then make the project sources stored there importable
# (distillation_trainer, state, data, networks are imported from it below).
drive.mount("/content/drive")
sys.path.append(PROJECT_DIR)

In [None]:
!pip install transformers[torch] datasets

In [None]:
from distillation_trainer import DistillationTrainer
from state import State
from data import Datapoint, get_wiki_dataloader
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, GPT2Config, TrainingArguments, Trainer
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import random
from transformers import AutoConfig, AutoModelForCausalLM
from copy import deepcopy
import math

# from VAE import VAE
from torch.utils.data import DataLoader
import datasets
from sklearn.model_selection import train_test_split
from networks.language_model import LanguageModelWrapper

# print(torch.cuda.memory_summary())

In [None]:
# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): this small GPT-2 config is not referenced by any later cell in this notebook.
config = GPT2Config(vocab_size=50257, n_positions=1024, n_ctx=1024, n_embd=128, n_layer=4, n_head=4)
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
# Reuse EOS as the padding token so batched tokenization can pad.
tokenizer.pad_token = tokenizer.eos_token
# The distilgpt2 weights are loaded (and moved to `device`) in the model-construction cell,
# right before they are wrapped. Loading them here as well only duplicated the download and
# the device allocation, because nothing uses them in between.
print(f"device is {device}")

In [None]:
block_size = 64  # tokens in each generated sentence
min_text_length = 128  # characters; shorter lines of the split are skipped

dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
dataset_train = dataset["train"]
# Pull the whole "text" column in one call instead of iterating row by row, which builds a
# dict per row. The resulting list (same texts, same order) is unchanged.
texts = [text for text in dataset_train["text"] if len(text) > min_text_length]

In [None]:
fraction = 0.005
# L1 is twice the number of texts sampled below (LEN_sample = L1 // 2); the first L1 texts
# feed the train_loader.
L1 = int(fraction * len(texts) * 2.0)
print(L1)

In [None]:
total_len = len(texts)
LEN_sample = L1 // 2
# Seed NumPy's global RNG so the same texts are drawn on every run.
np.random.seed(42)
# LEN_sample distinct indices drawn uniformly from [0, total_len).
idxs = np.random.choice(total_len, LEN_sample, replace=False)

samples_sentences = [texts[position] for position in idxs.tolist()]

In [None]:
# Split every sampled text into consecutive, non-overlapping blocks of exactly `block_size`
# token ids; a trailing partial block is dropped.
cropped_tokens = []

for text in samples_sentences:
    tokens = tokenizer(text)["input_ids"]
    # Only start a block where a full block still fits, so no partial block is ever created.
    # The previous append-then-pop approach inspected the *previous* text's last block (or
    # raised IndexError on the first text) whenever `tokens` was empty.
    cropped_tokens.extend(
        tokens[start : start + block_size] for start in range(0, len(tokens) - block_size + 1, block_size)
    )

In [None]:
L2 = len(cropped_tokens)
# Full blocks obtained vs. the target number of sampled texts.
print(L2, len(texts) * fraction)

We should generate at least 1% of the data → around 14k sentences, i.e. ~219 batches of 64 sentences.

Ideally we should be able to reach 10% → around 140k sentences, i.e. ~2.2k batches.

Note: the dataloader we pass to train the sentences should contain at least twice as many sentences as we will generate. Keep in mind that each original text yields roughly 2 blocks. So for 1% we need about 14k original texts (≈ 28k blocks) for each update. That takes around 17 minutes per generated batch, i.e. ~62 h in total (219 batches × 17 min).

In [None]:
# Real-text loader over the first L1 texts; it is handed to the DistillationTrainer below.
batch_size = 16
seq_len = block_size
train_loader = get_wiki_dataloader(texts[:L1], tokenizer, seq_len, batch_size)

In [None]:
# Later cells index `samples_sentences` as token-id blocks, so rebind it to the cropped blocks.
# NOTE(review): this shadows the list of raw texts sampled earlier; re-running the cropping
# cell after this one would feed token-id lists to the tokenizer instead of strings.
samples_sentences = cropped_tokens
sentences_per_batch = 64  # sampled blocks per batch handed to the trainer

# Drop the remainder so the blocks split evenly into batches. The previous `[:-extra]` slice
# returned an EMPTY tensor whenever extra == 0, silently discarding every block.
extra = len(samples_sentences) % sentences_per_batch
n_usable = len(samples_sentences) - extra
samples_dataloader = torch.tensor(samples_sentences[:n_usable], dtype=torch.long).reshape(
    -1, sentences_per_batch, block_size
)

In [None]:
state = State(
    # device and data shape
    device=device,
    vocab_size=len(tokenizer),
    batch_size=batch_size,
    seq_len=seq_len,
    # distill_* settings
    distill_steps=1,
    distill_epochs=1,
    distill_lr=3e-4,
    # remaining training settings
    lr=1e-2,
    epochs=1,
    decay_epochs=2,
    decay_factor=0.1,
    checkpoint_interval=1,
)

In [None]:
# Load pretrained distilgpt2, place it on `device`, and wrap it for the DistillationTrainer.
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2")
gpt2_model = gpt2_model.to(device)
model = LanguageModelWrapper(gpt2_model, state)

In [None]:
# Build the distillation trainer from the state, the wrapped model, the tokenizer, the wiki
# train_loader and the pre-batched sampled blocks (samples_dataloader, built above).
trainer = DistillationTrainer(state, model, tokenizer, train_loader, sampled_sentences=samples_dataloader)

In [None]:
# Run DistillationTrainer's generation loop (project code). Per the timing note above, each
# generated batch can take several minutes.
trainer.iterative_generation()

Check how the last generated sentence looks.

In [None]:
resulting_text, token_ids, labels = trainer.get_train_data_and_text()
# Show an original block next to what the trainer produced; previously the generated text was
# unpacked but never displayed.
# NOTE(review): which original block the generated text corresponds to is decided inside
# DistillationTrainer; samples_sentences[0] is only the first sampled block.
print("Original sentence:", tokenizer.decode(samples_sentences[0]))
print("Generated text:", resulting_text)

Compare a few original and generated examples.

In [None]:
# Load the generated sentences saved by an earlier generation run.
# NOTE(review): pickle.load can execute arbitrary code; only load files this project wrote.
with open("generated_text_backup", "rb") as fp:
    gen_text = pickle.load(fp)

# Print up to 5 original/generated pairs without raising IndexError on shorter lists.
# Assumes gen_text[i] was generated from samples_sentences[i] (TODO confirm).
n_examples = min(5, len(gen_text), len(samples_sentences))
for i in range(n_examples):
    print("Original sentence:", tokenizer.decode(samples_sentences[i]))
    print("Generated sentence: ", gen_text[i])

Evaluation

In [None]:
import pickle

# Generated token blocks saved by the generation run (pickle of a local, project-written file).
with open("generated_data", "rb") as handle:
    gen_data = pickle.load(handle)

# Number of generated blocks vs. number of original sampled blocks.
print(len(gen_data), len(samples_sentences))

In [None]:
def evaluate(training_dataset, test_dataset, model=None):
    """Train a causal LM for one epoch on ``training_dataset`` and report its perplexity on ``test_dataset``.

    Args:
        training_dataset: dataset with ``input_ids``/``labels`` columns used for training.
        test_dataset: dataset used only for the final evaluation pass.
        model: optional model to train. It is deep-copied first, so the caller's instance is
            never modified and repeated calls with the same model all start from the same
            weights (previously every call kept training the caller's object). When ``None``,
            a freshly initialised, untrained model with the distilgpt2 architecture is built.

    Returns:
        The perplexity ``exp(eval_loss)`` (``inf`` if that overflows). It is also printed.
    """
    if model is None:
        config = AutoConfig.from_pretrained("distilgpt2")
        eval_model = AutoModelForCausalLM.from_config(config)
    else:
        eval_model = deepcopy(model)

    # The evaluation strategy stays at its default ("no"): the only evaluation is the explicit
    # `evaluate()` call after training. Not passing it also sidesteps the
    # `evaluation_strategy` -> `eval_strategy` rename across transformers versions.
    training_args = TrainingArguments(
        "output",
        learning_rate=2e-5,
        weight_decay=0.01,
        num_train_epochs=1,
        save_strategy="no",
        report_to="none",
    )

    eval_trainer = Trainer(
        model=eval_model,
        args=training_args,
        train_dataset=training_dataset,
        eval_dataset=test_dataset,
    )
    eval_trainer.train()
    eval_results = eval_trainer.evaluate()

    try:
        perplexity = math.exp(eval_results["eval_loss"])
    except OverflowError:
        perplexity = float("inf")
    print(f"Perplexity: {perplexity:.2f}")
    return perplexity

In [None]:
def update_eval_dataset(evaluation_dataset):
    """Tokenize a dataset with a ``text`` column and regroup it into fixed-length LM examples.

    Uses the notebook-level ``tokenizer`` and ``block_size``. Within each 1000-example map
    batch, all token ids are concatenated and cut into consecutive chunks of exactly
    ``block_size`` tokens; the remainder of each batch is dropped. ``labels`` is a copy of
    ``input_ids`` (the model is expected to do any next-token shift itself).

    Returns:
        The tokenized and grouped dataset. The ``text`` column is removed.
    """

    def tokenize_function(examples):
        return tokenizer(examples["text"])

    def group_texts(examples):
        # Imported here so the function stays self-contained when `datasets` ships it to the
        # num_proc worker processes.
        from itertools import chain

        # Concatenate each column across the batch. `chain` is linear in the number of tokens,
        # whereas `sum(lists, [])` re-copies the growing list for every example (quadratic).
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the small remainder instead of padding.
        total_length = (total_length // block_size) * block_size
        # Split into chunks of block_size.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    evaluation_dataset = evaluation_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

    evaluation_dataset = evaluation_dataset.map(
        group_texts,
        batched=True,
        batch_size=1000,
        num_proc=4,
    )

    return evaluation_dataset

In [None]:
# The evaluation set does not depend on the subset size, so tokenize and group it once here
# instead of on every loop iteration.
# NOTE(review): these are the first `eval_L` filtered *train* texts, which may overlap the
# texts used for distillation (texts[:L1]) and sampling; consider dataset["validation"].
eval_L = 10000
evaluation_dataset = update_eval_dataset(datasets.Dataset.from_dict({"text": texts[:eval_L]}))

# A dedicated loop variable keeps the global `fraction` (set when sampling) from being overwritten.
for subset_fraction in [0.1, 0.2, 0.4, 0.8, 1]:
    new_length = int(len(gen_data) * subset_fraction)

    synthetic_dataset = datasets.Dataset.from_dict(
        {"input_ids": gen_data[:new_length], "labels": gen_data[:new_length]}
    )
    sampled_dataset = datasets.Dataset.from_dict(
        {"input_ids": samples_sentences[:new_length], "labels": samples_sentences[:new_length]}
    )

    # NOTE(review): 0.05 is a hard-coded percentage scale; confirm it matches `fraction` and len(gen_data).
    print(f"Total fraction is {round(subset_fraction * 0.05, 4)} %")

    evaluate(synthetic_dataset, evaluation_dataset)
    evaluate(sampled_dataset, evaluation_dataset)

In [None]:
import torch.nn as nn


class LSTMModel(nn.Module):
    """Small LSTM causal language model used as a non-transformer baseline.

    ``forward`` returns a dict with ``loss`` and ``logits``, so it can be trained and
    evaluated with the Hugging Face ``Trainer`` like the GPT-2 model.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: inputs are (batch, seq, dim). With the default batch_first=False,
        # the LSTM would recur over the *batch* axis and treat each token position as an
        # independent sequence.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids, labels=None):
        """Return next-token logits of shape (batch, seq, vocab) and, if ``labels`` are given, the LM loss.

        Leading dimensions of ``input_ids``/``labels`` beyond (batch, seq) are flattened into the batch.
        """
        input_ids = input_ids.reshape(-1, input_ids.shape[-1])
        embeddings = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embeddings)
        logits = self.fc(lstm_out)
        loss = None
        if labels is not None:
            labels = labels.reshape(-1, labels.shape[-1])
            # Next-token objective: the output at position t predicts token t + 1. Without the
            # shift the model only has to copy the token it is currently reading.
            shift_logits = logits[:, :-1, :]
            shift_labels = labels[:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        return {"loss": loss, "logits": logits}

In [None]:
# 2-layer LSTM baseline whose output layer covers the GPT-2 tokenizer's vocabulary.
lstm_model = LSTMModel(
    num_layers=2,
    hidden_dim=256,
    embedding_dim=128,
    vocab_size=tokenizer.vocab_size,
)

In [None]:
# The evaluation set does not depend on the subset size, so tokenize and group it once here
# instead of on every loop iteration.
# NOTE(review): these are the first `eval_L` filtered *train* texts, which may overlap the
# texts used for distillation (texts[:L1]) and sampling; consider dataset["validation"].
eval_L = 10000
evaluation_dataset = update_eval_dataset(datasets.Dataset.from_dict({"text": texts[:eval_L]}))

# A dedicated loop variable keeps the global `fraction` (set when sampling) from being overwritten.
for subset_fraction in [0.1, 0.2, 0.4, 0.8, 1]:
    new_length = int(len(gen_data) * subset_fraction)

    synthetic_dataset = datasets.Dataset.from_dict(
        {"input_ids": gen_data[:new_length], "labels": gen_data[:new_length]}
    )
    sampled_dataset = datasets.Dataset.from_dict(
        {"input_ids": samples_sentences[:new_length], "labels": samples_sentences[:new_length]}
    )

    # NOTE(review): 0.05 is a hard-coded percentage scale; confirm it matches `fraction` and len(gen_data).
    print(f"Total fraction is {round(subset_fraction * 0.05, 4)} %")

    # Give every run its own copy of the untrained LSTM. Passing `lstm_model` itself let each
    # run continue from the previous run's weights, biasing the synthetic-vs-sampled comparison.
    evaluate(synthetic_dataset, evaluation_dataset, deepcopy(lstm_model))
    evaluate(sampled_dataset, evaluation_dataset, deepcopy(lstm_model))