In [7]:
import os
import torch
import pandas as pd
from PIL import Image
import torch.nn as nn
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics.classification import Accuracy
from torch.utils.data import Dataset, DataLoader

In [None]:
class DeiTSmallLightningModel(pl.LightningModule):
    def __init__(self, num_classes: int, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()

        # Load DeiT-Small from Facebook's official repo via torch.hub
        self.model = torch.hub.load('facebookresearch/deit:main', 'deit_small_patch16_224', pretrained=True)

        # Freeze the transformer blocks (feature extractor)
        for param in self.model.blocks.parameters():
            param.requires_grad = False

        # Replace the classification head
        in_features = self.model.head.in_features
        self.model.head = nn.Linear(in_features, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self.train_acc(logits, y)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self.val_acc(logits, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

In [9]:
class CustomImageDataset(Dataset):
    def __init__(self, df, img_dir, transform=None, mode="train"):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.mode = mode

        self.crop_suffixes = [
            "_cropped_bottom_right_bright.png",
            "_cropped_bottom_left_bright.png",
        ]

        self.samples = []
        for _, row in self.df.iterrows():
            file_id = os.path.splitext(row["path"])[0]
            fold = row["fold"]
            label = int(row["class_numeric"])
            for suffix in self.crop_suffixes:
                img_path = os.path.join(self.img_dir, f"fold_{fold}", f"{file_id}{suffix}")
                if os.path.exists(img_path):
                    self.samples.append((img_path, label))
                else:
                    print(f"Image not found: {img_path}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, label

In [10]:
class CustomImageDataset(Dataset):
    def __init__(self, df, img_dir, transform=None, mode="train"):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.mode = mode

        self.crop_suffixes = [
            "_cropped_bottom_right_bright.png",
            "_cropped_bottom_left_bright.png",
        ]

        self.samples = []
        for _, row in self.df.iterrows():
            file_id = os.path.splitext(row["path"])[0]
            fold = row["fold"]
            label = int(row["class_numeric"])
            for suffix in self.crop_suffixes:
                img_path = os.path.join(self.img_dir, f"fold_{fold}", f"{file_id}{suffix}")
                if os.path.exists(img_path):
                    self.samples.append((img_path, label))
                else:
                    print(f"Image not found: {img_path}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            img = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, label



class DataModule(pl.LightningDataModule):
    def __init__(self, csv_path, image_dir, fold_val=0, batch_size=32):
        super().__init__()
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.fold_val = fold_val
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        self.train_df = self.df[self.df["fold"] != self.fold_val]
        self.val_df = self.df[self.df["fold"] == self.fold_val]
        self.train_dataset = CustomImageDataset(self.train_df, self.image_dir, transform=self.transform, mode="train")
        self.val_dataset = CustomImageDataset(self.val_df, self.image_dir, transform=self.transform, mode="val")
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)


In [11]:
# if __name__ == "__main__":
#     csv_path = "./data_with_folds.csv"
#     image_dir = ""
#     fold_val = 0
#     batch_size = 32
#     num_classes = len(pd.read_csv(csv_path)["class_numeric"].unique())

#     data_module = DataModule(csv_path,
#                             image_dir,
#                             fold_val=fold_val,
#                             batch_size=batch_size
#                             )
    
#     model = DeiTSmallLightningModel(num_classes=num_classes)

#     trainer = pl.Trainer(
#         max_epochs=5,
#         accelerator="cuda" if torch.cuda.is_available() else "cpu",
#         devices="auto",
#         precision="16-mixed"
#     )
    
#     trainer.fit(model, data_module)

In [12]:
import os

if __name__ == "__main__":
    csv_path = "./data_with_folds.csv"
    image_dir = ""
    fold_val = 0
    batch_size = 32
    model_name = "deit_small"
    weights_dir = "weights"
    os.makedirs(weights_dir, exist_ok=True)

    classifications_dir = "classifications"
    os.makedirs(classifications_dir, exist_ok=True)

    # Set number of classes from CSV
    df = pd.read_csv(csv_path)
    num_classes = df["class_numeric"].nunique()

    # Initialize data module and model
    data_module = DataModule(csv_path, image_dir, fold_val=fold_val, batch_size=batch_size)
    model = DeiTSmallLightningModel(num_classes=num_classes)

    # Define trainer with checkpointing
    checkpoint_path = os.path.join(weights_dir, f"{model_name}.ckpt")
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        devices="auto",
        precision="16-mixed",
        default_root_dir=weights_dir,
        callbacks=[pl.callbacks.ModelCheckpoint(
            dirpath=weights_dir,
            filename=model_name,
            save_top_k=1,
            monitor="val_acc",
            mode="max"
        )]
    )

    # Train the model
    trainer.fit(model, data_module)

    # Load best model for prediction
    best_model = DeiTSmallLightningModel.load_from_checkpoint(checkpoint_path, num_classes=num_classes)
    best_model.eval()
    best_model.freeze()

    # Inference on full dataset
    full_dataset = CustomImageDataset(df, image_dir, transform=data_module.transform, mode="val")
    full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)

    predictions = []
    paths = []
    for imgs, _ in full_loader:
        with torch.no_grad():
            preds = best_model(imgs.to(best_model.device))
            cls = preds.argmax(dim=1).cpu().tolist()
            predictions.extend(cls)

    # Construct prediction mapping
    flat_paths = [sample[0] for sample in full_dataset.samples]
    file_names = [os.path.basename(p).split('_')[0] + '.png' for p in flat_paths]

    # Add predictions to original df
    pred_df = pd.DataFrame({
        'path': file_names,
        f'cls_{model_name}': predictions
    })

    # Remove duplicates in case of augmentations
    pred_df = pred_df.groupby('path').agg(lambda x: x.mode()[0]).reset_index()

    merged_df = df.merge(pred_df, on="path", how="left")
    output_csv_path = os.path.join(classifications_dir, f"cls_{model_name}.csv")
    merged_df.to_csv(output_csv_path, index=False)

    print(f"Saved predictions to: {output_csv_path}")


Using cache found in C:\Users\andre/.cache\torch\hub\facebookresearch_deit_main
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\andre\anaconda3\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory C:\Users\andre\1JUPYTER\Raia\Odonto\OsteoporosisDetection-dev\fine_tuning\cnn\weights exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | VisionTransformer  | 21.7 M | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
373 K     Trainable params
21.3 M    Non-trainable params
21.7 M    Total params
86.667    Total estimated model 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\andre\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\andre\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Using cache found in C:\Users\andre/.cache\torch\hub\facebookresearch_deit_main


Saved predictions to: classifications\cls_deit_small.csv


In [14]:
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score, f1_score

# Load the CSV
df = pd.read_csv("C:/Users/andre/1JUPYTER/Raia/Odonto/OsteoporosisDetection-dev/fine_tuning/cnn/classifications/cls_deit_small.csv")

# Filter for fold == 0
df_fold_0 = df[df["fold"] == 0]

# Extract true and predicted labels
y_true = df_fold_0["class_numeric"]
y_pred = df_fold_0["cls_deit_small"]

# Compute metrics
conf_matrix = confusion_matrix(y_true, y_pred)
precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
accuracy = accuracy_score(y_true, y_pred)

# Print results
print("Confusion Matrix:\n", conf_matrix)
print("Precision (weighted):", precision)
print("Recall (weighted):", recall)
print("Accuracy:", accuracy)
print("F1 Score (weighted):", f1_score(y_true, y_pred, average='weighted', zero_division=0))

Confusion Matrix:
 [[98  8  6]
 [41 33  9]
 [ 8 12 16]]
Precision (weighted): 0.6273876931211076
Recall (weighted): 0.6363636363636364
Accuracy: 0.6363636363636364
F1 Score (weighted): 0.6157151444684368
