In [None]:
import os
# Tell wandb which notebook this run comes from (used when it cannot detect the name itself).
os.environ.update(WANDB_NOTEBOOK_NAME='football_segmentation_training.ipynb')

In [None]:
from datasets import load_from_disk

# Load the dataset saved to disk earlier; its rows are read as `image` and `mask` by `transforms`.
dataset = load_from_disk(dataset_path="Football/dataset")


In [None]:
from pycocotools.coco import COCO
annotation_file = 'Football/COCO_Football Pixel.json'
coco = COCO(annotation_file)
# Map contiguous indices 0..N-1 to category names, numbered in `coco.cats`
# iteration order -- these are NOT the original COCO category ids.
# NOTE(review): assumes the mask pixel values in `dataset` use these same indices -- confirm.
id2label = dict(enumerate(category['name'] for category in coco.cats.values()))
label2id = {name: index for index, name in id2label.items()}
id2label

In [None]:
import numpy as np

def freq_weighted_iou(pred, target, class_num, ignore_index=255):
    """Frequency-weighted IoU (FWIoU) of predicted vs. ground-truth label maps.

    Per-class intersection and union are summed over *all* images first, then
    each class IoU is weighted by that class's share of the non-ignored
    ground-truth pixels. Aggregating before dividing means a class that is
    absent from one image cannot score a spurious IoU of 1 for that image.

    Args:
        pred: integer array of predicted class indices, same shape as ``target``
            (e.g. ``(batch, height, width)``).
        target: integer array of ground-truth class indices.
        class_num: number of classes; only indices ``0 .. class_num - 1`` are scored.
        ignore_index: target value whose pixels are excluded from every count
            (255 matches the ``ignore_index`` passed to ``mean_iou`` in this
            notebook); ``None`` disables ignoring.

    Returns:
        FWIoU as a Python float in ``[0, 1]``, or ``nan`` when ``target`` has no
        pixel of any scored class.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if ignore_index is None:
        valid = np.ones(target.shape, dtype=bool)
    else:
        valid = target != ignore_index

    intersection = np.zeros(class_num, dtype=np.float64)
    union = np.zeros(class_num, dtype=np.float64)
    target_pixels = np.zeros(class_num, dtype=np.float64)
    # One class at a time keeps memory at O(pixels) instead of a (C, *shape) temporary.
    for cls in range(class_num):
        pred_c = (pred == cls) & valid
        target_c = (target == cls) & valid
        intersection[cls] = np.count_nonzero(pred_c & target_c)
        union[cls] = np.count_nonzero(pred_c | target_c)
        target_pixels[cls] = np.count_nonzero(target_c)

    total = target_pixels.sum()
    if total == 0:
        return float("nan")

    # A class with an empty union has no target pixels, hence zero weight;
    # give it IoU 0 rather than dividing by zero.
    iou = np.divide(intersection, union, out=np.zeros(class_num), where=union > 0)
    return float(np.dot(target_pixels / total, iou))

In [None]:
from transformers import Trainer, TrainingArguments, AutoModelForSemanticSegmentation, AutoImageProcessor, BeitForSemanticSegmentation
import evaluate

# Image preprocessing configuration shipped with the pretrained BEiT checkpoint
# (the model itself is built later from the same checkpoint).
image_processor = AutoImageProcessor.from_pretrained(
    'microsoft/beit-base-finetuned-ade-640-640'
)


In [None]:
# train_test_split shuffles by default; a fixed seed keeps the train/validation
# split (and therefore every reported metric) identical across kernel restarts.
split = dataset.train_test_split(test_size=0.1, seed=42)
train_ds, val_ds = split['train'], split['test']
print(len(train_ds), len(val_ds))


def transforms(example_batch):
    """Convert a batch of raw rows into model inputs.

    Args:
        example_batch: mapping with ``"image"`` and ``"mask"`` columns, as
            passed by ``Dataset.set_transform``.

    Returns:
        The image processor's output for the images, plus ``"labels"``: the
        masks stacked into one int64 array.
    """
    inputs = image_processor(list(example_batch["image"]))
    # Masks bypass the image processor and keep their stored resolution.
    # Stacking requires every mask in the batch to have the same shape.
    inputs['labels'] = np.array(list(example_batch["mask"]), dtype=np.int64)
    return inputs


# Apply on the fly at access time, so the raw columns stay on disk untouched.
train_ds.set_transform(transforms)
val_ds.set_transform(transforms)

In [None]:
# Re-display the label mapping (the same value is already shown by the COCO cell).
# NOTE(review): redundant output; candidate for removal.
id2label

In [None]:
def compute_weight_in_ds(ds, num_classes=None):
    """Count how many label pixels of each class appear in a dataset.

    Args:
        ds: indexable sequence of examples, each with a ``'labels'`` array of
            integer class indices. When ``ds`` has a transform set (as
            ``train_ds`` does), each ``ds[idx]`` runs that full transform.
        num_classes: number of classes to count; defaults to ``len(id2label)``.

    Returns:
        float64 array of length ``num_classes``. Entry ``c`` is the total number
        of pixels labelled ``c`` across all of ``ds``. Values outside
        ``[0, num_classes)`` (e.g. an ignore index such as 255) are not counted.
    """
    if num_classes is None:
        num_classes = len(id2label)
    freqs = np.zeros(num_classes)
    for idx in range(len(ds)):
        labels = np.asarray(ds[idx]['labels']).ravel()
        # Drop out-of-range values so bincount neither fails on negatives nor
        # grows past num_classes.
        in_range = labels[(labels >= 0) & (labels < num_classes)].astype(np.int64)
        freqs += np.bincount(in_range, minlength=num_classes)
    return freqs

class_counts = compute_weight_in_ds(train_ds)
# Average number of pixels per training image for each class.
class_counts /= len(train_ds)
# Classes never seen in train_ds get a dummy count of 1 so the inverse below is finite.
class_counts[class_counts == 0] = 1
print("Class Counts:", class_counts)

# Inverse-frequency weights (pixels per image / mean class pixels per image).
# Computed for reference only; the loss uses the hand-tuned weights below.
label_shape = train_ds[0]['labels'].shape
inverse_freq_weights = label_shape[0] * label_shape[1] / class_counts
print("Class Weights (Inverse Frequency):", inverse_freq_weights)

# Hand-tuned per-class weights actually used by the weighted loss: one entry per id2label class.
class_weights = np.array([4, 8, 1, 1, 16, 16, 1, 4, 8, 2, 2])
assert len(class_weights) == len(id2label), (
    f"class_weights has {len(class_weights)} entries but there are {len(id2label)} classes"
)
print("Class Weights (hand-tuned, used for training):", class_weights)

In [None]:
from torch.nn import CrossEntropyLoss
import torch
import torch.nn as nn
# Per-class weights for the weighted cross-entropy loss, as a float tensor.
# Falls back to CPU when no GPU is visible instead of crashing this cell.
weight = torch.tensor(
    class_weights,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    dtype=torch.float,
)

class BeitWithWeightedCrossEntropy(BeitForSemanticSegmentation):
    """BEiT semantic-segmentation model trained with a class-weighted cross-entropy.

    Only ``compute_loss`` is overridden: it passes the module-level ``weight``
    tensor (one weight per class) to ``CrossEntropyLoss``. Everything else is
    inherited from ``BeitForSemanticSegmentation``.
    """

    def compute_loss(self, logits, auxiliary_logits, labels):
        """Weighted cross-entropy of the main (and optional auxiliary) logits.

        Args:
            logits: 4-D main-head logits ``(batch, classes, h, w)``, as required
                by bilinear interpolation and ``CrossEntropyLoss``.
            auxiliary_logits: auxiliary-head logits of the same layout, or ``None``.
            labels: integer masks; their last two dims give the target resolution.

        Returns:
            Scalar loss ``main + config.auxiliary_loss_weight * auxiliary``.
            The auxiliary term is added only when ``auxiliary_logits`` is given.
        """
        target_size = labels.shape[-2:]
        # Upsample logits to the label resolution so every labelled pixel is scored.
        upsampled_logits = nn.functional.interpolate(
            logits, size=target_size, mode="bilinear", align_corners=False
        )
        # CrossEntropyLoss requires the weights on the same device/dtype as its input.
        class_weight = weight.to(device=upsampled_logits.device, dtype=upsampled_logits.dtype)
        loss_fct = CrossEntropyLoss(
            weight=class_weight, ignore_index=self.config.semantic_loss_ignore_index
        )
        loss = loss_fct(upsampled_logits, labels)
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=target_size, mode="bilinear", align_corners=False
            )
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            # Out-of-place add: do not mutate the main-loss tensor in the autograd graph.
            loss = loss + self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

model = BeitWithWeightedCrossEntropy.from_pretrained(
    'microsoft/beit-base-finetuned-ade-640-640',
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's head was trained for a different label set (presumably
    # ADE20k's), so let mismatched head weights be re-initialised.
    ignore_mismatched_sizes=True,
)


In [None]:

import torch

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    """Segmentation metrics for one ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair of numpy arrays. The logits are
            bilinearly upsampled to the labels' spatial size before the argmax.

    Returns:
        dict with the ``mean_iou`` outputs (numpy arrays turned into lists so
        they can be logged) plus a ``freq_weighted_iou`` entry.
    """
    with torch.no_grad():
        logits, labels = eval_pred
        full_res_logits = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = full_res_logits.argmax(dim=1).detach().cpu().numpy()

        num_classes = len(id2label)
        raw_metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_classes,
            ignore_index=255,
            reduce_labels=False,
        )
        metrics = {
            name: value.tolist() if isinstance(value, np.ndarray) else value
            for name, value in raw_metrics.items()
        }
        metrics['freq_weighted_iou'] = freq_weighted_iou(pred_labels, labels, num_classes)

        return metrics
    
training_args = TrainingArguments(
    output_dir="models",
    # Optimisation
    learning_rate=6e-5,
    num_train_epochs=15,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # NOTE(review): drop_last appears to apply to the evaluation dataloader as well,
    # so up to (eval batch size - 1) validation images may be skipped per evaluation -- confirm.
    dataloader_drop_last=True,
    # Evaluate and checkpoint every 20 steps; the best checkpoint by mean_iou is
    # restored when training ends.
    evaluation_strategy="steps",  # renamed `eval_strategy` in newer transformers releases
    eval_steps=20,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="mean_iou",
    # Logging
    logging_strategy="steps",
    logging_steps=1,
    log_level="error",
    report_to="wandb",
    # Keep the raw `image` / `mask` columns: the on-the-fly `transforms` reads them.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)


In [None]:
# Run fine-tuning. Because load_best_model_at_end=True, the Trainer reloads the
# best checkpoint (by eval mean_iou, per metric_for_best_model) when training ends.
trainer.train()

In [None]:
import wandb
# Close the W&B run the Trainer opened (report_to="wandb") so it is marked finished
# before the kernel is reused.
wandb.finish()