In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

##import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install evaluate

In [None]:

import torch
from transformers import VivitForVideoClassification, VivitImageProcessor
import imageio
import cv2
import numpy as np
import os
import time

In [None]:
CHECKPOINT = "google/vivit-b-16x2-kinetics400"
TRAIN_DIR = "/kaggle/input/exercise-class/kinetics1/train"

# Class names in label-index order, derived with the same rule torchvision's
# DatasetFolder.find_classes uses (sorted sub-directory names), so the model's label ids
# line up with the integer labels the training dataset yields.
class_names = sorted(entry.name for entry in os.scandir(TRAIN_DIR) if entry.is_dir())
label2id = {name: idx for idx, name in enumerate(class_names)}
id2label = {idx: name for name, idx in label2id.items()}

image_processor = VivitImageProcessor.from_pretrained(CHECKPOINT)
# Fine-tune with a freshly initialised classification head sized for this dataset's classes.
# ignore_mismatched_sizes lets from_pretrained drop the checkpoint's differently sized
# classifier weights (Kinetics-400 classes, per the checkpoint name).
model = VivitForVideoClassification.from_pretrained(
    CHECKPOINT,
    num_labels=len(class_names),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)


In [None]:
import  torchvision.transforms.v2 
class ToPILImage:
    """Convert every frame of a clip to a ``PIL.Image``.

    Wraps torchvision's ``ToPILImage`` so it can be applied to a whole list of frames
    as one step of a ``Compose`` pipeline.
    """

    def __init__(self):
        self.to_pil = torchvision.transforms.ToPILImage()

    def __call__(self, frames):
        """Return a new list with ``self.to_pil`` applied to each element of ``frames``.

        Args:
            frames: iterable of frames accepted by torchvision's ``ToPILImage``
                (numpy arrays from ``gif_loader`` in this notebook).
        """
        # No per-sample printing: this runs once per video per epoch.
        return [self.to_pil(frame) for frame in frames]
class Resize:
    """Resize every frame of a clip to a fixed ``(h, w)``.

    Args:
        h: target height in pixels.
        w: target width in pixels.
    """

    def __init__(self, h, w):
        self.h, self.w = h, w
        self.resize = torchvision.transforms.v2.Resize((self.h, self.w))

    def __call__(self, frames):
        """Apply the resize to each frame and return the results as a new list."""
        return list(map(self.resize, frames))

class ToTensor:
    """Convert every frame of a clip (PIL images in this pipeline) to a tensor.

    NOTE(review): torchvision marks ``v2.ToTensor`` as deprecated in favour of
    ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``; kept as-is here.
    """

    def __init__(self):
        self.to_tensor = torchvision.transforms.v2.ToTensor()

    def __call__(self, frames):
        """Return a new list holding ``self.to_tensor(frame)`` for each frame."""
        return list(map(self.to_tensor, frames))
    
class SampleFrames:
    """Stack a clip's frames into one tensor with exactly ``num_frames`` frames.

    Clips shorter than ``num_frames`` are looped from the start: output frame ``i`` is
    input frame ``i % T``, where ``T`` is the input length. Longer clips are uniformly
    subsampled in time with torchvision's ``UniformTemporalSubsample``.

    Args:
        num_frames: length of every output clip.
    """

    def __init__(self, num_frames=32):
        # NOTE(review): presumably chosen to match the ViViT checkpoint's clip length —
        # confirm against model.config.num_frames.
        self.num_frames = num_frames

    def __call__(self, frames):
        """Return a tensor of shape ``(num_frames, *frames[0].shape)``.

        Args:
            frames: non-empty list of equally-shaped tensors, one per frame
                (``(C, H, W)`` when produced by ``ToTensor`` in this notebook).

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("SampleFrames received a clip with no frames")
        clip = torch.stack(frames)
        num_available = clip.shape[0]

        if num_available < self.num_frames:
            # Loop the clip. This is equivalent to tiling it num_frames // T times and
            # then appending its first num_frames % T frames.
            indices = torch.arange(self.num_frames, device=clip.device) % num_available
            return clip[indices]

        if num_available == self.num_frames:
            # Uniformly subsampling to the same length would select every frame anyway.
            return clip

        subsample = torchvision.transforms.v2.UniformTemporalSubsample(self.num_frames)
        # UniformTemporalSubsample reads time from dim -4 of (..., T, C, H, W); a leading
        # batch dim is added for the call and removed afterwards.
        return subsample(clip.unsqueeze(0)).squeeze(0)
        
class Normalize:
    """Normalize every frame of a clip with per-channel ``mean`` and ``std``.

    NOTE(review): not used by the ``transform`` pipeline built in the cells shown here.

    Args:
        mean: per-channel means forwarded to torchvision's ``v2.Normalize``.
        std: per-channel standard deviations forwarded to ``v2.Normalize``.
    """

    def __init__(self, mean, std):
        self.normalize = torchvision.transforms.v2.Normalize(mean=mean, std=std)

    def __call__(self, frames):
        """Return a new list holding each element of ``frames`` normalized."""
        return list(map(self.normalize, frames))

In [None]:
from torchvision.datasets import DatasetFolder
from torchvision.transforms.v2 import Compose

# Define a custom loader for GIF files
def gif_loader(path):
    """Read every frame of the GIF at ``path`` as a 3-channel RGB array.

    Used as the ``loader`` of the ``DatasetFolder`` datasets in this notebook.

    Args:
        path: filesystem path of a ``.gif`` file.

    Returns:
        list of numpy arrays, one per frame. Each is ``(H, W, 3)``, assuming the decoder
        yields single-channel, RGB or RGBA frames.
    """
    frames = []
    # memtest=False disables imageio's default limit on the total decoded size.
    for frame in imageio.mimread(path, memtest=False):
        if frame.ndim == 2 or frame.shape[-1] == 1:
            # Grayscale frame: expand to RGB so every frame has 3 channels. The ndim check
            # comes first so a 2-D frame whose width happens to be 4 is not taken for RGBA.
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif frame.shape[-1] == 4:
            # RGBA frame: drop the alpha channel.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
        frames.append(frame)
    return frames
# Define a transform to apply to each frame of the GIF
# Preprocessing applied to every clip:
# numpy frames -> PIL images -> 224x224 -> tensors -> one fixed-length clip tensor.
# NOTE(review): no normalization step is applied; image_processor is only handed to Trainer,
# so pixel values reach the model in ToTensor's output range — confirm this matches how the
# checkpoint was pretrained.
transform = Compose([
    ToPILImage(),      # frames from gif_loader -> PIL images
    Resize(224, 224),  # every frame to 224x224
    ToTensor(),        # PIL images -> tensors
    SampleFrames(),    # stack, then loop or subsample to the default clip length
])

DATA_ROOT = "/kaggle/input/exercise-class/kinetics1"


def _gif_split(split):
    """Build a ``DatasetFolder`` of GIF clips under ``DATA_ROOT/<split>`` using ``transform``."""
    return DatasetFolder(
        root=f"{DATA_ROOT}/{split}",
        loader=gif_loader,
        extensions=("gif",),
        transform=transform,
    )


train_dataset = _gif_split("train")
test_dataset = _gif_split("test")

print(f"Custom dataset created with {len(train_dataset)} samples.")
print(f"Test dataset created with {len(test_dataset)} samples.")

In [None]:
# Class names in label-index order: DatasetFolder assigns integer label i to classes[i]
# (class folders sorted by name).
train_dataset.classes

In [None]:
from torch.utils.data import DataLoader

# Sanity check: pull one training batch through the full loading/transform pipeline.
batch_size = 2    # batch size for this check only; Trainer uses its own batch size from `args`
shuffle = True    # shuffle the data
num_workers = 1   # number of subprocesses for data loading

data_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
)

data, labels = next(iter(data_loader))
print("Batch 0:")
print(f"Data shape: {len(data)} {data[0].shape}")  # batch size, then shape of the first clip
print(f"Labels: {labels}")

In [None]:
# Accelerate-style FSDP launch configuration.
# NOTE(review): not referenced by any other cell shown here (Trainer is built without it);
# it presumably only takes effect once written to an accelerate config file — confirm.
fsdp_config = {
    "compute_environment": "LOCAL_MACHINE",
    "debug": False,
    "distributed_type": "FSDP",
    "downcast_bf16": "no",
    "fsdp_config": {
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_backward_prefetch_policy": "BACKWARD_PRE",
        "fsdp_forward_prefetch": False,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_offload_params": False,
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_state_dict_type": "SHARDED_STATE_DICT",
        "fsdp_sync_module_states": True,
        # Must name the transformer block class of the model being wrapped (ViViT);
        # a BERT layer class does not exist in this model, so auto-wrapping would fail.
        "fsdp_transformer_layer_cls_to_wrap": "VivitLayer",
        "fsdp_use_orig_params": True
    },
    "machine_rank": 0,
    "main_training_function": "main",
    "mixed_precision": "bf16",
    "num_machines": 1,
    "num_processes": 2,
    "rdzv_backend": "static",
    "same_network": True,
    "tpu_env": [],
    "tpu_use_cluster": False,
    "tpu_use_sudo": False,
    "use_cpu": False,
}

In [None]:
from transformers import TrainingArguments, Trainer
model_name = 'vivit'
# NOTE(review): new_model_name is not used below (output_dir is model_name); confirm which was intended.
new_model_name = f"{model_name}-finetuned-workouts"
num_epochs = 10
per_device_batch_size = 5
gradient_accumulation_steps = 4

# num_train_epochs, rather than a hand-computed max_steps, lets Trainer account for
# gradient accumulation and the number of devices. Each optimizer step consumes
# per_device_batch_size * gradient_accumulation_steps samples per device, so
# (len(train_dataset) // 5) * num_epochs steps would have run ~4x num_epochs epochs.
args = TrainingArguments(
    model_name,                        # output_dir for checkpoints
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",             # must match eval_strategy for load_best_model_at_end
    learning_rate=5e-5,
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=per_device_batch_size,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # the key returned by the "accuracy" metric
    push_to_hub=False,
    num_train_epochs=num_epochs,
    report_to="none",
)

In [None]:
import evaluate

# Hugging Face `evaluate` accuracy metric; compute_metrics calls metric.compute(...) on it.
metric = evaluate.load("accuracy")

In [None]:
def compute_metrics(eval_pred):
    """Score an evaluation run with the module-level ``metric`` (accuracy).

    Args:
        eval_pred: object exposing ``predictions`` (per-class scores, one row per
            sample) and ``label_ids`` (true labels); Trainer passes an EvalPrediction.

    Returns:
        Whatever ``metric.compute`` returns for the arg-max predictions.
    """
    scores = eval_pred.predictions
    predicted = np.argmax(scores, axis=1)
    references = eval_pred.label_ids
    return metric.compute(predictions=predicted, references=references)

In [None]:
def collate_fn(examples):
    """Batch ``(video, label)`` samples into the keyword arguments the ViViT model takes.

    Args:
        examples: list of samples whose element 0 is a video tensor and element 1 an
            integer label (as yielded by the ``DatasetFolder`` datasets); all videos
            must share one shape.

    Returns:
        dict with ``"pixel_values"`` (videos stacked along a new leading batch dim) and
        ``"labels"`` (1-D tensor of the labels).
    """
    # No per-batch printing: this runs for every training and evaluation batch.
    pixel_values = torch.stack([example[0] for example in examples])
    labels = torch.tensor([example[1] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [None]:
from transformers import default_data_collator

# Wire the model, training arguments, GIF datasets, custom collator and accuracy metric
# together. The image processor is handed over so Trainer saves it alongside checkpoints.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=image_processor,  # NOTE(review): newer transformers versions call this `processing_class`
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Run fine-tuning. Because args sets load_best_model_at_end=True with
# metric_for_best_model="accuracy", the best-accuracy checkpoint is reloaded at the end.
train_results = trainer.train()