In [None]:
# Log in to the Hugging Face Hub through the CLI; --add-to-git-credential asks the CLI to
# also register the token as a git credential.
# NOTE(review): `$TOKEN$` looks like a placeholder — confirm how the real token is injected,
# and keep it out of the saved notebook (e.g. read it from the environment).
!huggingface-cli login --token $TOKEN$ --add-to-git-credential 

In [None]:
cd src

In [None]:
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from configuration_mamba import MambaConfig
from datasets import Dataset, DatasetDict, load_dataset
from modeling_mamba import MambaForCausalLM, MambaModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

In [None]:
def load_json(json_path):
    """Read a JSON file and return the decoded Python object.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file contents (a dict for the
        settings files used in this notebook).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the platform's
    # default text encoding.
    with open(json_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def load_data(data_path):
    """Load a dataset by handing ``data_path`` straight to ``datasets.load_dataset``.

    Args:
        data_path: Dataset identifier or path, passed through unchanged.

    Returns:
        The object returned by ``load_dataset(data_path)``.
    """
    dataset = load_dataset(data_path)
    return dataset


def load_model(config):
    """Instantiate a ``MambaForCausalLM`` from the given configuration.

    Args:
        config: Configuration object passed to the model constructor
            (``make_config`` builds a suitable ``MambaConfig``).

    Returns:
        The newly constructed ``MambaForCausalLM``.
    """
    model = MambaForCausalLM(config)
    return model


def load_tokenizer(path):
    """Load a tokenizer with ``AutoTokenizer.from_pretrained``.

    Args:
        path: Tokenizer name or path, passed through unchanged.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained(path)``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(path)
    return loaded_tokenizer


def make_config(json):
    """Build a ``MambaConfig`` from a settings mapping.

    Args:
        json: Mapping holding the model hyperparameters. Note that inside this
            function the name refers to the mapping, not the ``json`` module.
            Every field listed below must be present.

    Returns:
        A ``MambaConfig`` populated with those fields.

    Raises:
        KeyError: If any required field is missing from the mapping.
    """
    # Fields copied one-to-one from the mapping into the config, in this order.
    field_names = (
        "vocab_size",
        "d_model",
        "d_conv",
        "expand",
        "conv_bias",
        "bias",
        "n_layer",
        "dt_rank",
        "pad_vocab_size_multiple",
        "initializer_range",
    )
    config_kwargs = {field: json[field] for field in field_names}
    return MambaConfig(**config_kwargs)


def split_data(data, train_fraction=0.8):
    """Split a dataset into contiguous train and validation parts.

    The first ``int(len(data) * train_fraction)`` rows become the training
    split and all remaining rows become the validation split; rows are not
    shuffled.

    Args:
        data: Dataset supporting ``len()`` and ``.select(indices)``.
        train_fraction: Share of rows assigned to the training split.
            Defaults to 0.8, the value previously hard-coded here.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"valid"``.
    """
    total_rows = len(data)
    train_size = int(total_rows * train_fraction)

    # Pass ranges directly; materializing them into lists is unnecessary.
    ds_train = data.select(range(train_size))
    ds_valid = data.select(range(train_size, total_rows))

    # DatasetDict comes from the `datasets` import at the top of the notebook.
    return DatasetDict({"train": ds_train, "valid": ds_valid})


def tokenize(data):
    """Tokenize the ``"tgt"`` column of a batch and keep the non-empty sequences.

    Relies on the module-level ``tokenizer`` object, which must be defined
    before this function is called.

    Args:
        data: Batch mapping with a ``"tgt"`` entry that is passed to the tokenizer.

    Returns:
        ``{"input_ids": [...]}`` holding the token-id sequences whose reported
        length is non-zero.
    """
    # Ask for truncation at 1024 tokens, the overflowing tokens, and per-sequence lengths.
    encoded = tokenizer(
        data["tgt"],
        truncation=True,
        max_length=1024,
        return_overflowing_tokens=True,
        return_length=True,
    )
    kept_ids = [
        ids
        for n_tokens, ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens != 0
    ]
    return {"input_ids": kept_ids}


class MambaTrainer(Trainer):
    """Trainer whose loss is next-token cross-entropy derived from ``input_ids`` alone."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the causal language-modeling loss for one batch.

        The labels are the input ids shifted left by one position, so the logits
        at position ``t`` are scored against the token at position ``t + 1``.

        Args:
            model: Model called as ``model(input_ids)``; the first element of its
                output is used as the logits.
            inputs: Batch mapping containing ``"input_ids"``. If it also contains
                ``"attention_mask"``, positions whose mask is 0 are excluded
                from the loss.
            return_outputs: When true, return ``(loss, outputs)`` instead of the
                loss alone.
            num_items_in_batch: Accepted so that Trainer versions which pass this
                keyword do not fail; it is not used in the computation.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` if ``return_outputs``
            is true.
        """
        # Index rather than pop() so the caller's batch is left untouched.
        input_ids = inputs["input_ids"]
        outputs = model(input_ids)
        lm_logits = outputs[0]

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        # Exclude masked-out (padding) positions from the loss: -100 is the default
        # ignore_index of CrossEntropyLoss.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            keep = attention_mask[:, 1:].to(shift_labels.device).bool()
            shift_labels = shift_labels.masked_fill(~keep, -100)

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return (lm_loss, outputs) if return_outputs else lm_loss

In [None]:
def pretrain(json, train_args):
    """Train a Mamba language model described by a JSON settings file and upload it.

    Args:
        json: Path of the settings file (inside this function the name refers
            to that path, not the ``json`` module). The file must provide the
            fields read by ``make_config`` plus ``"tokenizer_path"``,
            ``"data"`` and ``"upload_path"``.
        train_args: Training arguments forwarded to ``MambaTrainer``.

    Returns:
        None. The trained model is pushed to ``settings["upload_path"]``.
    """
    settings = load_json(json)
    mamba = load_model(make_config(settings))
    hub_tokenizer = load_tokenizer(settings["tokenizer_path"])

    raw_splits = load_data(settings["data"])
    # split_data() is deliberately not applied: the loaded dataset is expected to
    # already expose "train" and "valid" splits, both of which are indexed below.
    encoded_splits = raw_splits.map(
        tokenize,
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )

    runner = MambaTrainer(
        model=mamba,
        tokenizer=hub_tokenizer,
        args=train_args,
        train_dataset=encoded_splits["train"],
        eval_dataset=encoded_splits["valid"],
    )
    runner.train()

    # Only the model is uploaded here.
    mamba.push_to_hub(settings["upload_path"])

In [None]:
# Hyperparameters for the training run, gathered in one place for easy tuning.
TRAINING_OPTIONS = {
    "output_dir": "mamba",
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 1,
    "evaluation_strategy": "steps",
    "num_train_epochs": 4,
    "weight_decay": 0.1,
    "warmup_steps": 1_000,
    "lr_scheduler_type": "cosine",
    "learning_rate": 5e-4,
    "save_steps": 5_000,
    "fp16": True,
}
train_args = TrainingArguments(**TRAINING_OPTIONS)

# tokenize() reads this module-level `tokenizer`, so it must exist before pretrain() runs.
tokenizer = AutoTokenizer.from_pretrained("google/byt5-large")

# "model_parameters.json" must already exist; the config-generation cell further down writes it.
pretrain("model_parameters.json", train_args)

In [None]:
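## CONFIG GENERATION: writes model_parameters.json, the settings file that pretrain() reads. Run this section before the training cell.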

In [None]:
import json

# File that pretrain() is pointed at in the training cell.
CONFIG_PATH = "model_parameters.json"

data = {
    # Model hyperparameters consumed by make_config().
    "vocab_size": 7000,
    "d_model": 256,
    "d_conv": 4,
    "expand": 2,
    "conv_bias": True,
    "bias": False,
    "n_layer": 4,
    "dt_rank": "auto",
    "pad_vocab_size_multiple": 8,
    "initializer_range": 0.02,
    # Tokenizer, upload target and dataset consumed by pretrain().
    "tokenizer_path": "google/byt5-large",
    "upload_path": "mlsquare/samantar_merged_with_train_val",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}

# Serialize the settings as indented JSON text.
json_data = json.dumps(data, indent=4)

# Persist the settings to disk.
with open(CONFIG_PATH, "w") as json_file:
    json_file.write(json_data)

# Read the file back to confirm it round-trips.
with open(CONFIG_PATH, "r") as json_file:
    loaded_data = json.loads(json_file.read())

# Show what was loaded.
print(loaded_data)