In [None]:
# !pip install git+https://github.com/adapter-hub/adapters.git
# !pip install wandb
# !pip install pandas
# !pip install datasets

# requires ipykernel package

In [None]:
# !pip install accelerate -U

In [None]:
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, Seq2SeqTrainingArguments, BertTokenizer, Seq2SeqTrainer, AutoModel, AutoModelForCausalLM, DataCollatorForSeq2Seq, GenerationConfig, DataCollatorWithPadding
from adapters import BnConfig, Seq2SeqAdapterTrainer, AdapterTrainer, BertAdapterModel, init
import wandb
import torch
import pandas as pd
from datasets import Dataset
import os
import datasets
import numpy as np



In [None]:
# Pick the GPU when torch can see one, otherwise fall back to the CPU, and report the choice
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

In [None]:
# Build the encoder-decoder model with the same IgBert checkpoint on both sides
igbert_checkpoint = "Exscientia/IgBert"
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=igbert_checkpoint,
    decoder_pretrained_model_name_or_path=igbert_checkpoint,
)
# adapters.init is applied to the model before any adapter is added to it
init(model)

In [None]:
# Bottleneck adapter configuration and registration
adapter_name = "seq2seq_adapter"
config = BnConfig(
    mh_adapter=True,
    output_adapter=True,
    reduction_factor=16,
    non_linearity="relu",
)

# Register the adapter, make it the active one, then select it for training
model.add_adapter(adapter_name, config=config)
model.set_active_adapters(adapter_name)
model.train_adapter(adapter_name)

In [None]:
#print(f"print EncoderDecoderModel: {model}")

# Load the IgBert tokenizer from Hugging Face (the model itself is built in an earlier cell).
# Sequences are written as space-separated upper-case residue letters, so lower-casing is
# disabled explicitly instead of relying on whatever the hub tokenizer config specifies.
tokenizer = BertTokenizer.from_pretrained("Exscientia/IgBert", do_lower_case=False)

In [None]:
# Generation settings, grouped by purpose (values unchanged)
generation_config = GenerationConfig(
    # special token ids, taken from the tokenizer
    decoder_start_token_id=tokenizer.cls_token_id,
    eos_token_id=tokenizer.sep_token_id,
    pad_token_id=tokenizer.pad_token_id,

    # output length bounds
    min_length=50,
    max_length=512,

    # beam search
    num_beams=4,
    num_return_sequences=1,
    early_stopping=True,
    length_penalty=-2.0,

    # sampling / distribution adjustment
    do_sample=True,
    top_k=50,
    temperature=0.001,

    # repetition control
    no_repeat_ngram_size=3,
    repetition_penalty=1,

    # NOTE(review): copied from the encoder config; confirm generate() actually uses it
    vocab_size=model.config.encoder.vocab_size,

    # caching and what generate() returns
    use_cache=True,
    output_logits=True,
    output_scores=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
)


In [None]:
# Persist the generation settings as generation_config/generation_config_1.json
generation_config.save_pretrained(
    "generation_config",
    config_file_name="generation_config_1.json",
)

In [None]:
# Reload the saved generation settings; the name is also reused in the run name later on
generation_config_name = "generation_config_1"
generation_config = GenerationConfig.from_pretrained(
    "generation_config",
    config_file_name=f"{generation_config_name}.json",
)

In [None]:
# Training hyper-parameters
batch_size, num_train_epochs, learning_rate = 8, 5, 1e-4

# Run name encodes the hyper-parameters and the generation config in use
run_name = f"new_small_data_with_adapters_batch_size_{batch_size}_epochs_{num_train_epochs}_automodel_lr_{learning_rate}_{generation_config_name}"

# Checkpoints and logs each get a directory derived from the run name
output_dir, logging_dir = f"./{run_name}", f"./{run_name}_logging"

In [None]:
training_args = Seq2SeqTrainingArguments(
    # run identity and output locations
    run_name=run_name,
    output_dir=output_dir,
    logging_dir=logging_dir,
    report_to="wandb",
    # optimisation
    learning_rate=learning_rate,
    weight_decay=0.01,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # logging, evaluation and checkpoint retention
    logging_strategy="steps",
    logging_steps=10,
    evaluation_strategy="steps",
    save_total_limit=3,
    # evaluation generates sequences with the generation_config object
    predict_with_generate=True,
    generation_config=generation_config,
)

# Collator built from the tokenizer and the model
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Create directories if they do not exist
for directory in (training_args.output_dir, training_args.logging_dir):
    os.makedirs(directory, exist_ok=True)

# Log in to Weights & Biases
#wandb.login()


# Start the Weights & Biases run under the same name as the training run
wandb.init(name=run_name, project="bert2bert-translation")

In [None]:
def load_data(file_path):
    """Read paired sequences from a text file into a DataFrame.

    Each non-blank line is stripped and split on the literal ' [SEP] '.
    Lines that yield exactly two parts become one row (first part ->
    'heavy', second part -> 'light'). Any other non-blank line is
    reported on stdout and skipped; blank lines are skipped silently.

    Args:
        file_path: Path of the text file to read (decoded as UTF-8).

    Returns:
        pandas.DataFrame with the columns ['heavy', 'light'], one row per
        valid line, in file order. Empty if no line is valid.
    """
    sequences = []
    # Stream the file line by line instead of buffering every line first,
    # and pin the encoding so the result does not depend on the platform.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            entry = line.strip()
            if not entry:
                # Blank lines carry no data; do not report them as invalid.
                continue
            split_entry = entry.split(' [SEP] ')
            if len(split_entry) == 2:
                sequences.append(split_entry)
            else:
                print(f"Skipping invalid entry: {entry}")

    return pd.DataFrame(sequences, columns=['heavy', 'light'])



In [None]:
!pip show datasets


In [None]:
# Locations of the training and validation files (both live in the same directory)

paired_data_dir = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/BERT2BERT/data'
train_file_path = f'{paired_data_dir}/paired_full_seqs_sep_train_no_ids_small_SPACE_separated.txt'
val_file_path = f'{paired_data_dir}/paired_full_seqs_sep_val_no_ids_small_SPACE_separated.txt'
#test_file_path = '/ibmm_data2/oas_database/paired_lea_tmp/paired_model/train_test_val_datasets/heavy_sep_light_seq/paired_full_seqs_sep_test_no_ids_space_separated_SMALL.txt'

In [None]:
# Parse the training file first, then the validation file
train_df, val_df = (load_data(path) for path in (train_file_path, val_file_path))
#test_df = load_data(test_file_path)


# Token-length cap applied when tokenizing the encoder inputs and the decoder targets
encoder_max_length = decoder_max_length = 200

def process_data_to_model_inputs(batch):
    """Tokenize a batch of paired sequences into encoder inputs and decoder labels.

    The "light" entries become the encoder inputs (input_ids / attention_mask),
    padded or truncated to encoder_max_length. The "heavy" entries become the
    targets, padded or truncated to decoder_max_length: their attention mask is
    stored as decoder_attention_mask and their token ids as labels, with every
    pad-token id replaced by -100.

    Args:
        batch: Mapping with "light" and "heavy" entries; it is updated in place.

    Returns:
        The same batch object with the four model fields added.
    """
    light_encoding = tokenizer(
        batch["light"], padding="max_length", truncation=True, max_length=encoder_max_length
    )
    heavy_encoding = tokenizer(
        batch["heavy"], padding="max_length", truncation=True, max_length=decoder_max_length
    )

    batch["input_ids"] = light_encoding.input_ids
    batch["attention_mask"] = light_encoding.attention_mask
    #batch["decoder_input_ids"] = heavy_encoding.input_ids
    batch["decoder_attention_mask"] = heavy_encoding.attention_mask

    # Labels are the target token ids with PAD positions replaced by -100
    pad_id = tokenizer.pad_token_id
    batch["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in target_ids]
        for target_ids in heavy_encoding.input_ids
    ]

    return batch

In [None]:
# Convert the dataframes to Hugging Face datasets
paired_columns = ['heavy', 'light']
train_dataset = Dataset.from_pandas(train_df[paired_columns])
val_dataset = Dataset.from_pandas(val_df[paired_columns])
#test_dataset = Dataset.from_pandas(test_df[['heavy', 'light']])


def _tokenize_split(split):
    """Tokenize one dataset split and expose its model columns in torch format."""
    tokenized = split.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=batch_size,
    )
    # "decoder_input_ids" is left out of the exposed columns
    tokenized.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "decoder_attention_mask", "labels"],
    )
    return tokenized


# Same processing for both splits: training first, then validation
train_data = _tokenize_split(train_dataset)
val_data = _tokenize_split(val_dataset)


# test_data = test_dataset.map(
#     process_data_to_model_inputs,   
#     batched=True,
#     batch_size=batch_size,
# )   

# # "decoder_input_ids",
# test_data.set_format(
#     type="torch", columns=["input_ids", "attention_mask", "decoder_attention_mask", "labels"],
# )




# Sanity check: show the first two raw records (heavy and light sequences) of train_dataset
print(f"first example heavy and light seq {train_dataset[0]}, {train_dataset[1]}")


# Build the adapter-aware seq2seq trainer from the pieces prepared above
trainer = Seq2SeqAdapterTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=val_data,
    adapter_names=["seq2seq_adapter"],
)


In [None]:
# Display the generation config currently attached to the model (shown as the cell output)
model.generation_config

In [None]:
# Copy the tokenizer's special-token ids onto the model config:
# CLS starts decoding, SEP ends a sequence, PAD is the padding id.
special_token_ids = {
    "decoder_start_token_id": tokenizer.cls_token_id,
    "eos_token_id": tokenizer.sep_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
for config_field, token_id in special_token_ids.items():
    setattr(model.config, config_field, token_id)

In [None]:
#print(f"trainer.get_train_dataloader().collate_fn: {trainer.get_train_dataloader().collate_fn}")

# Train the model
# Runs training with the `trainer` object; the call's return value is shown as the cell output.
trainer.train()
#trainer.evaluate()




In [None]:
# Re-derive the device for the inference cells: GPU when available, otherwise CPU
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")


In [None]:
# Generate one sequence for a single hand-written prompt
input_prompt = "S T G V A F M E I N G L R S D D T A T Y F C A I N R V G D R G S N P S Y F Q D W G Q G T R V T V S S "
print(f"input_prompt: {input_prompt}")

# Tokenize the prompt (padded to 512 tokens) and move the tensors to the selected device
inputs = tokenizer(input_prompt, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)

print(f"attention_mask: {attention_mask}")

#input_ids = tokenizer.encode(input_prompt, return_tensors="pt").to(device)
print(f"input_ids: {input_ids}")

# Inference setup: keep the model on the same device as the inputs and switch it to
# evaluation mode so dropout is disabled (it may still be in training mode after training).
model.to(device)
model.eval()

# Generate text using the model
generated_seq = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_length=100,
    output_scores=True,
    return_dict_in_generate=True,
)

# Turn output scores to probabilities
# generated_seq_probs = torch.nn.functional.softmax(generated_seq['scores'][0], dim=-1)

# Access the first element in the generated sequence
sequence = generated_seq["sequences"][0]

# Print the generated token ids
print(f"encoded heavy sequence: {sequence}.")

# Convert the generated IDs back to text
generated_text = tokenizer.decode(sequence, skip_special_tokens=True)

print("decoded heavy sequence: ", generated_text)

# print(test_data)

# Parse the test file with the same loader used for the training data
test_file_path = '/kaggle/input/test-file/paired_full_seqs_sep_test_no_ids_space_separated_SMALL.txt'
test_df = load_data(test_file_path)


# Keep only the light-chain column; these are the inputs for generation
light_sequences = test_df["light"]
num_light_sequences = len(light_sequences)

print("light_sequences: ", light_sequences)
print(f"length of light sequences {num_light_sequences}")

generated_heavy_seqs = []

# Upper bound on how many test sequences to run generation for
max_test_sequences = 50

# Inference setup: keep the model on the same device as the inputs and switch it to
# evaluation mode so dropout is disabled (it may still be in training mode after training).
model.to(device)
model.eval()

# Positional slicing (.iloc) yields at most `max_test_sequences` items and simply fewer
# when the test set is smaller, so the loop cannot index past the end of the data.
for light_seq in light_sequences.iloc[:max_test_sequences]:
    inputs = tokenizer(light_seq, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Explicit keyword arguments are passed alongside the shared generation_config
    generated_seq = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=100,
        output_scores=True,
        return_dict_in_generate=True,
        generation_config=generation_config,
    )

    # Access the first element in the generated sequence
    sequence = generated_seq["sequences"][0]

    # Print the generated token ids
    print(f"encoded heavy sequence: {sequence}.")

    # Convert the generated IDs back to text
    generated_text = tokenizer.decode(sequence, skip_special_tokens=True)

    print("decoded heavy sequence: ", generated_text)

    generated_heavy_seqs.append(generated_text)


print("generated_heavy_seqs:")
# print each generated sequence on new line
for seq in generated_heavy_seqs:
    print(seq)