# TRAIN

prepare the dataset

In [None]:
import json

with open("reddit_data.json", "r") as f:
    data = json.load(f)

for item in data:
    if isinstance(item.get("summary"), dict):
        item["summary"] = json.dumps(item["summary"], ensure_ascii=False)  
    elif item.get("summary") is None:
        item["summary"] = ""  

with open("reddit_data_cleaned.json", "w") as f:
    json.dump(data, f, indent=4, ensure_ascii=False)

Generating train split: 100 examples [00:00, 7803.79 examples/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary'],
        num_rows: 100
    })
})





load the dataset

In [2]:
from datasets import load_dataset

dataset = load_dataset("json", data_files="reddit_data_cleaned.json")
print(dataset)


  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary'],
        num_rows: 100
    })
})


preprocess the dataset

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

def preprocess_dataset(examples):
    # all_text = examples["selftext"] + "\n" + "\n".join(examples["comments"])
    all_text = [selftext + "\n" + "\n".join(comments) for selftext, comments in zip(examples["selftext"], examples["comments"])]
    model_inputs = tokenizer(all_text, max_length=1024, truncation=True, padding=True)
    labels = tokenizer(examples["summary"], max_length=1024, truncation=True, padding=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_dataset, batched=True)

split the dataset

In [4]:
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

train_df = tokenized_dataset["train"].to_pandas()

train_split, val_split = train_test_split(train_df, test_size=0.2)

train_dataset = Dataset.from_pandas(train_split)
val_dataset = Dataset.from_pandas(val_split)

dataset2 = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

print(dataset2)

DatasetDict({
    train: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 20
    })
})


In [5]:
print(dataset2)

DatasetDict({
    train: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 20
    })
})


since the validation set has summaries in it, lets remove it

In [6]:
def remove_summary_from_validation(example):
    if 'summary' in example:
        del example['summary']
    return example

dataset2['validation'] = dataset2['validation'].map(remove_summary_from_validation)
print(dataset2)

Map: 100%|██████████| 20/20 [00:00<00:00, 2449.66 examples/s]

DatasetDict({
    train: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'summary', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['title', 'selftext', 'comments', 'index', 'input_ids', 'attention_mask', 'labels', '__index_level_0__'],
        num_rows: 20
    })
})





prepare training stuff

In [7]:
from transformers import AutoModelForSeq2SeqLM, TrainingArguments, Trainer

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

training_args = TrainingArguments(
    output_dir="../results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=12,
    push_to_hub=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset2["train"],
    eval_dataset=dataset2["validation"]
)

training time

In [8]:
import torch
torch.cuda.empty_cache()

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,1.969823
2,No log,0.929973
3,No log,0.731302
4,No log,0.681642
5,No log,0.684628
6,No log,0.711838
7,No log,0.730015
