In [71]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from torchvision.models import ResNet18_Weights
from PIL import Image
import pandas as pd
from pathlib import Path
from tqdm import tqdm

In [72]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG so the initialisation of newly created layers and
# the DataLoader shuffle order are repeatable after "Restart & Run All".
# (Some GPU kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Input locations, relative to the notebook's working directory.
FRAME_DIR = Path("../data/frames")                      # one sub-folder of .jpg frames per video
LABEL_PATH = Path("../data/labels/labels_task1.csv")    # CSV read with pandas in the next cell

# Data-loading and training settings.
BATCH_SIZE = 2          # videos per batch
NUM_WORKERS = 2         # DataLoader worker processes
EPOCHS = 10             # passes over the training loader
NUM_CLASSES = 4         # size of the classifier output
SEQUENCE_LENGTH = 16    # frames taken from each video
IMAGE_SIZE = 224        # frames are resized to IMAGE_SIZE x IMAGE_SIZE

In [73]:
# Load the label table.
df = pd.read_csv(LABEL_PATH)

# Keep only the videos whose frame folder exists and holds at least one .jpg.
available_videos = {p.name for p in FRAME_DIR.iterdir() if p.is_dir() and any(p.glob("*.jpg"))}
df = df[df["VIDEO"].isin(available_videos)].reset_index(drop=True)

print(f"Vídeos disponíveis com frames: {len(df)}")

# Fail here with a clear message instead of an obscure error further down
# when nothing is left to train on.
if df.empty:
    raise ValueError(f"Nenhum vídeo do CSV tem frames extraídos em {FRAME_DIR}")

# The classifier has NUM_CLASSES outputs, so class indices must lie in
# [0, NUM_CLASSES - 1]; NaN labels are rejected by the same test.
out_of_range = df[~df["GRS"].between(0, NUM_CLASSES - 1)]
if not out_of_range.empty:
    raise ValueError(
        f"Labels GRS fora do intervalo [0, {NUM_CLASSES - 1}] nos vídeos: "
        f"{out_of_range['VIDEO'].tolist()}"
    )

# Store the filtered labels next to the original label file.
df.to_csv(LABEL_PATH.with_name("filtered_labels_task1.csv"), index=False)

Vídeos disponíveis com frames: 30


In [74]:
class GRSDataset(Dataset):
    """Dataset yielding a fixed-length stack of frames plus the GRS label of one video.

    Every row of ``df`` must provide a ``VIDEO`` column (name of the sub-folder
    of ``frame_dir`` that holds the video's ``*.jpg`` frames) and a ``GRS``
    column (the class label).

    Args:
        df: Table with one row per video.
        frame_dir: Directory containing one sub-folder of frames per video.
        transform: Callable applied to each RGB ``PIL.Image``. It must return
            tensors of identical shape so that they can be stacked. When
            ``None``, ``transforms.ToTensor()`` is used.
        sequence_length: Number of frames returned per video (at least 1).

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, df, frame_dir, transform=None, sequence_length=16):
        if sequence_length < 1:
            raise ValueError(f"sequence_length deve ser >= 1, recebido {sequence_length}")
        self.data = df
        self.frame_dir = Path(frame_dir)
        # A missing transform used to crash in __getitem__ ("NoneType is not
        # callable"); fall back to a plain PIL -> tensor conversion instead.
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.sequence_length = sequence_length

    def __len__(self):
        """Return the number of videos (rows of the label table)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the frames and label of the video at positional index ``idx``.

        Returns:
            ``(images, label)`` where ``images`` stacks ``sequence_length``
            transformed frames along a new first dimension and ``label`` is a
            Python ``int``.

        Raises:
            IndexError: If the video's folder contains no ``*.jpg`` file.
        """
        row = self.data.iloc[idx]
        # str() keeps the path join working when pandas parsed the ids as numbers.
        video_id = str(row["VIDEO"])
        frame_path = self.frame_dir / video_id

        # Frames are ordered by path, i.e. lexicographically.
        # NOTE(review): this assumes zero-padded frame numbers — confirm.
        frames = sorted(frame_path.glob("*.jpg"))
        if not frames:
            raise IndexError(f"Nenhum frame encontrado para o vídeo {video_id} em {frame_path}")

        # Only the first `sequence_length` frames are used; shorter videos are
        # padded by repeating their last frame.
        selected = frames[:self.sequence_length]
        selected += [selected[-1]] * (self.sequence_length - len(selected))

        images = []
        for frame_file in selected:
            # The context manager releases the file handle deterministically.
            with Image.open(frame_file) as img:
                images.append(self.transform(img.convert("RGB")))

        # A plain int collates to an int64 tensor, which is what
        # nn.CrossEntropyLoss needs even if the CSV column was parsed as float.
        label = int(row["GRS"])

        return torch.stack(images), label

In [75]:
# Per-frame preprocessing, applied in this order to every PIL image:
#   1. resize to IMAGE_SIZE x IMAGE_SIZE pixels,
#   2. convert to a tensor,
#   3. normalise each channel with the mean/std below
#      (presumably the ImageNet statistics expected by the pretrained backbone — confirm).
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One item per row of `df`; each item is a stack of SEQUENCE_LENGTH frames.
dataset = GRSDataset(
    df=df,
    frame_dir=FRAME_DIR,
    transform=transform,
    sequence_length=SEQUENCE_LENGTH,
)

# Samples are reshuffled every epoch and loaded by NUM_WORKERS worker processes.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [76]:
class GRSClassifier(nn.Module):
    """Per-frame ResNet-18 features fed to an LSTM, followed by a linear classifier.

    Args:
        num_classes: Number of output logits.
        hidden_size: Hidden size of the LSTM (defaults to the previously
            hard-coded 128).
    """

    def __init__(self, num_classes=4, hidden_size=128):
        super().__init__()
        self.cnn = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        # Read the feature width from the backbone before dropping its
        # classification head instead of hard-coding it.
        feature_dim = self.cnn.fc.in_features
        self.cnn.fc = nn.Identity()
        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)`` computed from the LSTM output
            at the last time step.

        Raises:
            ValueError: If ``x`` does not have exactly five dimensions.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Entrada deve ter 5 dimensões (B, T, C, H, W), recebido shape {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # Fold time into the batch dimension so the CNN sees B*T independent
        # images. reshape (unlike view) also accepts non-contiguous input.
        feats = self.cnn(x.reshape(B * T, C, H, W))
        # Restore the (batch, time, feature) layout expected by the LSTM.
        feats = feats.reshape(B, T, -1)
        lstm_out, _ = self.lstm(feats)
        return self.fc(lstm_out[:, -1])


In [77]:
# Model, loss and optimiser. All parameters (CNN backbone included) are
# handed to Adam, so the whole network is trained.
model = GRSClassifier(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0  # sum of the per-batch losses of this epoch
    correct = 0       # correctly classified training samples
    total = 0         # training samples seen

    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        # CrossEntropyLoss needs int64 class indices; the cast is a no-op when
        # the labels already have that dtype.
        labels = labels.to(device=device, dtype=torch.long)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # Guard against a loader that produced no samples (avoids ZeroDivisionError).
    acc = correct / total if total else 0.0
    # "Loss" is the sum over batches, not a per-batch or per-sample mean.
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {total_loss:.4f} | Acc: {acc:.4f}")

100%|██████████| 15/15 [00:19<00:00,  1.30s/it]


Epoch 1/10 | Loss: 19.8555 | Acc: 0.4000


100%|██████████| 15/15 [00:19<00:00,  1.30s/it]


Epoch 2/10 | Loss: 11.9393 | Acc: 0.8667


100%|██████████| 15/15 [00:18<00:00,  1.23s/it]


Epoch 3/10 | Loss: 7.1890 | Acc: 0.9667


100%|██████████| 15/15 [00:20<00:00,  1.36s/it]


Epoch 4/10 | Loss: 4.0895 | Acc: 1.0000


100%|██████████| 15/15 [00:18<00:00,  1.25s/it]


Epoch 5/10 | Loss: 2.8734 | Acc: 0.9667


100%|██████████| 15/15 [00:19<00:00,  1.30s/it]


Epoch 6/10 | Loss: 1.8131 | Acc: 1.0000


100%|██████████| 15/15 [00:19<00:00,  1.29s/it]


Epoch 7/10 | Loss: 1.3714 | Acc: 1.0000


100%|██████████| 15/15 [00:19<00:00,  1.33s/it]


Epoch 8/10 | Loss: 0.9774 | Acc: 1.0000


100%|██████████| 15/15 [00:19<00:00,  1.29s/it]


Epoch 9/10 | Loss: 0.8520 | Acc: 1.0000


100%|██████████| 15/15 [00:17<00:00,  1.20s/it]

Epoch 10/10 | Loss: 0.6703 | Acc: 1.0000





In [78]:
# Persist only the learned parameters (the state_dict), not the module object.
model_path = Path("../outputs/models/grs_classifier.pt")
# torch.save does not create missing folders; make sure the target directory
# exists so a finished training run is not lost to a missing path.
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)
print("✅ Modelo salvo com sucesso.")

✅ Modelo salvo com sucesso.


In [79]:
# Quick sanity check: run one batch through the trained model.
# NOTE: the batch comes from train_loader, i.e. data the model was trained on,
# so matching predictions say nothing about generalisation.
model.eval()  # switch the model to inference mode
sample = next(iter(train_loader))
x, y = sample

# Gradients are not needed for a forward-only pass.
with torch.no_grad():
    pred = model(x.to(device))

print("Previsão:", pred.argmax(dim=1).cpu().numpy())
print("Label real:", y.numpy())

Previsão: [0 2]
Label real: [0 2]
