In [17]:
from torchvision.models import ResNet50_Weights

import src.utils as utils
import os
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split


In [2]:
# Build the image index through the project helper and echo it as the cell output.
dataset_root = 'data/plantvillage/plantvillage dataset'
df = utils.build_dataset(dataset_root)
df

Unnamed: 0,Format,Species,Healthy,Disease,Folder,FileName
0,color,Apple,False,Apple_scab,data/plantvillage/plantvillage dataset\color\A...,00075aa8-d81a-4184-8541-b692b78d398a___FREC_Sc...
1,color,Apple,False,Apple_scab,data/plantvillage/plantvillage dataset\color\A...,01a66316-0e98-4d3b-a56f-d78752cd043f___FREC_Sc...
2,color,Apple,False,Apple_scab,data/plantvillage/plantvillage dataset\color\A...,01f3deaa-6143-4b6c-9c22-620a46d8be04___FREC_Sc...
3,color,Apple,False,Apple_scab,data/plantvillage/plantvillage dataset\color\A...,0208f4eb-45a4-4399-904e-989ac2c6257c___FREC_Sc...
4,color,Apple,False,Apple_scab,data/plantvillage/plantvillage dataset\color\A...,023123cb-7b69-4c9f-a521-766d7c8543bb___FREC_Sc...
...,...,...,...,...,...,...
162911,segmented,Tomato,False,Tomato_Yellow_Leaf_Curl_Virus,data/plantvillage/plantvillage dataset\segment...,ffb295c9-f14e-4a15-831a-bf905da7fcb6___UF.GRC_...
162912,segmented,Tomato,False,Tomato_Yellow_Leaf_Curl_Virus,data/plantvillage/plantvillage dataset\segment...,ffe08ccc-c55e-4ca2-9234-2906b98b8d05___YLCV_NR...
162913,segmented,Tomato,False,Tomato_Yellow_Leaf_Curl_Virus,data/plantvillage/plantvillage dataset\segment...,ffe996e5-c8dc-47b7-bca2-4fc25e5ac57c___UF.GRC_...
162914,segmented,Tomato,False,Tomato_Yellow_Leaf_Curl_Virus,data/plantvillage/plantvillage dataset\segment...,fff42f1b-7ec4-46e3-9269-45932e63635e___YLCV_GC...


In [12]:
class PlantDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs for healthy-vs-diseased classification.

    Each row of ``df`` locates one image file through its ``Folder`` and
    ``FileName`` columns; the ``Healthy`` column supplies the binary label.

    Args:
        df: DataFrame with at least ``Folder``, ``FileName`` and ``Healthy`` columns.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            tuple: ``(image, label)``; ``label`` is 1 when ``Healthy`` is truthy, else 0.
        """
        # Materialise the row once instead of once per column lookup.
        row = self.df.iloc[idx]
        img_path = os.path.join(row['Folder'], row['FileName'])

        # The context manager releases the file handle deterministically;
        # convert() hands back a separate image that outlives the `with` block.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        label = 1 if row['Healthy'] else 0  # binary classification

        if self.transform:
            image = self.transform(image)
        return image, label


In [13]:
# Per-channel statistics expected by torchvision's ImageNet-pretrained weights.
# The notebook fine-tunes a pretrained ResNet-50, so inputs must be normalised
# the same way or the pretrained features see out-of-distribution values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light augmentation, then resize, tensor conversion, normalisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic — no augmentation, same size and normalisation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [14]:
# Drop rows that carry no Healthy label.
has_label = df['Healthy'].notna()
df = df[has_label]

# NOTE(review): the frame lists several Format values (e.g. color, segmented).
# If the same leaf photo appears once per format, this row-level split may place
# variants of one photo on both sides of the split — confirm and group if so.

# Stratified 80/20 hold-out split with a fixed seed.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['Healthy'],
    random_state=42,
)

train_dataset = PlantDataset(train_df, transform=train_transform)
val_dataset = PlantDataset(val_df, transform=val_transform)

# Only the training loader shuffles; validation keeps its order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [15]:
# Pick the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

device

'cuda:0'

In [18]:
# Use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ResNet-50 initialised from torchvision's default weights.
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)

# Replace the final layer with a single-logit head for binary classification.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=1)

model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\chris/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 37.6MB/s]


In [None]:
# Binary classification on raw logits: the loss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(5):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # Cast labels to float and add a trailing dim so they match the (N, 1) logits.
        labels = labels.to(device)
        labels = labels.float().unsqueeze(1)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
