In [22]:
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

In [38]:
# Load the merged dataset
training_data = pd.read_csv('video_training_data_10k.csv')

# Keep 10 rows for testing
test_data = training_data.sample(n=10, random_state=42)
train_data = training_data.drop(test_data.index)

In [39]:
training_data

Unnamed: 0,video_id,generated_description,description
0,10046243,ants are eating on the ground ants are eating ...,Ants eating dead insect
1,1005626710,a large orange tori tori tori tori tori tori t...,"Kyoto,japan-sep 4,2017: timelapse of the visit..."
2,1006641379,a man is seen in the middle of a flooded stree...,Circa 1940s - a film about copper mining and s...
3,1006733308,a bed with a pair of black pants and a white s...,Black kit classic menswear. men's accessories ...
4,1007094259,a woman doing yoga in the park a woman doing y...,Young asian woman yoga outdoors keep calm and ...
...,...,...,...
994,9643991,a close up of a curtain with a white backgroun...,White linen cloth on the wind
995,9709337,a person is laying in a hammol a person is lay...,Pan from person's legs resting on a hammock to...
996,9769340,aerial view of the hong skyline and the hong r...,"Bridge to haeundae, south korea, wide shot, la..."
997,9940922,a small dog is playing with a stick a small do...,Dog looking at camera turning head chewing bon...


In [47]:
# Define a custom Dataset class
class VideoDescriptionDataset(Dataset):
    def __init__(self, data, tokenizer, max_input_length=128, max_target_length=128):
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.data = data

        self.inputs = []
        self.targets = []

        for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing Data"):
            # Swap `description` and `generated_description` for target and input
            input_enc = tokenizer(row['generated_description'], max_length=max_input_length, padding='max_length', truncation=True, return_tensors='pt')
            target_enc = tokenizer(row['description'], max_length=max_target_length, padding='max_length', truncation=True, return_tensors='pt')
            self.inputs.append(input_enc)
            self.targets.append(target_enc)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if idx >= len(self.inputs):  # Check if idx is out of bounds
            raise IndexError(f"Index {idx} is out of bounds for the dataset.")
        
        item = {key: val.squeeze(0) for key, val in self.inputs[idx].items()}
        item['labels'] = self.targets[idx]['input_ids'].squeeze(0)
        return item

In [48]:
# Initialize the tokenizer and model
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to('cuda')




In [49]:
# Initialize the dataset
train_dataset = VideoDescriptionDataset(train_data, tokenizer)
test_dataset = VideoDescriptionDataset(test_data, tokenizer)

Processing Data: 100%|██████████| 989/989 [00:02<00:00, 455.93it/s]
Processing Data: 100%|██████████| 10/10 [00:00<00:00, 477.46it/s]


In [50]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',  # Directory to save model checkpoints and logs
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    report_to='none'  # Disable reporting to avoid the deprecated warning
)

In [51]:
# Define data collator
data_collator = lambda data: {
    'input_ids': torch.stack([f['input_ids'] for f in data]),
    'attention_mask': torch.stack([f['attention_mask'] for f in data]),
    'labels': torch.stack([f['labels'] for f in data]),
}

In [52]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

In [53]:
# Train the model
trainer.train()




Epoch,Training Loss,Validation Loss
1,0.8915,0.886428
2,0.7752,0.798386
3,0.8546,0.784755
4,0.6506,0.776618
5,0.5835,0.774487
6,0.7077,0.770232
7,0.7446,0.766309
8,0.8103,0.767289
9,0.7386,0.765439
10,0.7391,0.765402


TrainOutput(global_step=2480, training_loss=0.9222029691742313, metrics={'train_runtime': 210.4722, 'train_samples_per_second': 46.99, 'train_steps_per_second': 11.783, 'total_flos': 334632604139520.0, 'train_loss': 0.9222029691742313, 'epoch': 10.0})

In [54]:
# Generate text for the 10 test rows
model.eval()
test_predictions = []
for batch in tqdm(test_dataset, desc="Generating Predictions"):
    input_ids = batch['input_ids'].unsqueeze(0).to('cuda')
    attention_mask = batch['attention_mask'].unsqueeze(0).to('cuda')
    outputs = model.generate(input_ids, attention_mask=attention_mask)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    test_predictions.append(generated_text)

Generating Predictions: 100%|██████████| 10/10 [00:01<00:00,  8.27it/s]


In [56]:
# Show the output for the 10 rows
for i, (row, prediction) in enumerate(zip(test_data.iterrows(), test_predictions)):
    idx, row_data = row
    print(f"Row {i+1}:")
    print(f"Actual Description: {row_data['description']}")
    print(f"Input Generated Description: {row_data['generated_description']}")
    print(f"Generated Description: {prediction}")
    print()


Row 1:
Actual Description: New york - 5 march, 2020: charming young woman walk on street use phone white animation сloud technology internet networking device online storage computing icon network connection virtual interface
Input Generated Description: young woman using her smartphone in the street a woman is looking at her phone while walking down the street a woman is walking down the street while looking at her phone young woman using her smartphone in the street a woman walking down a street while looking at her phone a woman walking down a street with a cell young woman walking down the street with social icons around her young woman using mobile phone in the street young woman using her smartphone in the street young woman using her smartphone in the street young woman walking down the street with social icons around her young woman walking down the street with social icons around her young woman walking down the street with social icons around her young woman using mobile phon

10k videos training