<a href="https://colab.research.google.com/github/berfingundem/BladderCancer/blob/main/MedViTCode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Gerekli kütüphaneler
import os, shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 2. Mount Google Drive so the dataset under /content/drive/MyDrive is reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

# 3. Data directory and class folders
data_dir = "/content/drive/MyDrive/bladder_set"
class_names = ['HGC', 'LGC', 'NST', 'NTL']
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# 4. Copy each class folder into train/<class> and val/<class> with an 80/20 split.
for class_name in class_names:
    class_path = os.path.join(data_dir, class_name)
    # os.listdir returns entries in arbitrary order; sorting makes random_state=42
    # yield the same split on every run. This matters because copies from earlier
    # runs are never deleted: a different split on a rerun would leave the same
    # image in both train/ and val/.
    # Only regular files are kept, since shutil.copy cannot copy a directory.
    images = sorted(
        name
        for name in os.listdir(class_path)
        if os.path.isfile(os.path.join(class_path, name))
    )
    train_images, val_images = train_test_split(images, test_size=0.2, random_state=42)
    for split_dir, split_images in ((train_dir, train_images), (val_dir, val_images)):
        dest_dir = os.path.join(split_dir, class_name)
        os.makedirs(dest_dir, exist_ok=True)
        for img in split_images:
            shutil.copy(os.path.join(class_path, img), os.path.join(dest_dir, img))

# 5. Image preprocessing shared by train and val: resize to 224x224, convert to a
# tensor, then normalize with a single mean/std value applied to every channel.
_transform_steps = [
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_transform_steps)

# 6. Datasets and data loaders
train_dataset = ImageFolder(train_dir, transform=transform)
val_dataset = ImageFolder(val_dir, transform=transform)
# ImageFolder assigns label indices by sorted folder name. The confusion matrix
# and ROC plots label index i with class_names[i], so both orders must agree;
# fail loudly here rather than silently mislabel the plots.
for split_name, dataset in (("train", train_dataset), ("val", val_dataset)):
    if dataset.classes != class_names:
        raise ValueError(
            f"{split_name} class folders {dataset.classes} do not match "
            f"class_names {class_names}"
        )
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
num_classes = len(class_names)

# 7. MidViT model building blocks
class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool.

    The convolution uses padding=1, so only the max-pool changes the spatial
    size: height and width are halved (floored).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        # A single Sequential named `block` keeps state_dict keys as block.0.*, block.1.*.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

class TransformerBlock(nn.Module):
    """Pre-norm Transformer layer: residual self-attention, then a residual MLP.

    Args:
        dim: Token embedding size; must be divisible by ``heads``
            (``nn.MultiheadAttention`` rejects it otherwise).
        heads: Number of attention heads.

    Tokens are laid out as ``(batch, seq, dim)`` because the attention layer
    is built with ``batch_first=True``; the output has the same shape.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        # Normalize once and reuse the result as query, key and value.
        h = self.norm1(x)
        # The attention map is never used, so don't ask MultiheadAttention to
        # build it; this also allows PyTorch's fused attention path.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class MidViT(nn.Module):
    """CNN stem + one Transformer block + MLP classifier head.

    Three ConvBlocks (each ending in a 2x max-pool) map a ``(B, 3, S, S)`` image
    to a ``(B, 256, S//8, S//8)`` feature map. Each spatial cell becomes one
    256-d token for the Transformer, and all tokens are flattened into the head.

    Args:
        num_classes: Number of output logits.
        img_size: Side length of the square input images. Defaults to 224,
            which matches this notebook's Resize and gives a 28x28 token grid.
            The classifier's first Linear layer is sized from it, so inputs must
            have this size.

    Raises:
        ValueError: If ``img_size`` is smaller than 8 (the token grid would be empty).
    """

    def __init__(self, num_classes=4, img_size=224):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be >= 8, got {img_size}")
        self.cnn = nn.Sequential(
            ConvBlock(3, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
        )
        # Three successive floor-halvings by MaxPool2d(2) equal one floor division by 8.
        self.flatten_size = img_size // 8
        self.seq_len = self.flatten_size ** 2  # one token per feature-map cell
        self.embed_dim = 256  # channel count of the last ConvBlock
        # NOTE(review): no positional embedding is added before attention, so the
        # Transformer itself is order-agnostic; spatial layout only reaches the
        # model through the position-dependent flatten in the classifier.
        self.transformer = TransformerBlock(dim=self.embed_dim)
        # NOTE(review): at img_size=224 this first Linear takes 784 * 256 = 200,704
        # inputs (~103M weights), which dominates the model's parameter count.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.seq_len * self.embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.cnn(x)                   # (B, 256, S//8, S//8)
        x = x.flatten(2).transpose(1, 2)  # (B, seq_len, 256) token sequence
        x = self.transformer(x)
        return self.classifier(x)         # (B, num_classes) logits

# 8. Training setup: device, model, loss, optimizer and per-epoch metric history.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = MidViT(num_classes=num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 14

# One entry is appended per epoch by the training loop.
train_loss_list = []
val_loss_list = []
train_acc_list = []
val_acc_list = []

# 9. Training loop: one optimization pass over train_loader, then a full
# validation pass, per epoch.
for epoch in range(num_epochs):
    model.train()
    train_loss, correct, total = 0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss /= len(train_loader)  # mean of the per-batch losses
    train_acc = correct / total
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)

    model.eval()
    val_loss, correct, total = 0, 0, 0
    # Reset every epoch, so after the loop these hold the final epoch's
    # validation labels, predictions and softmax scores.
    y_true, y_pred, y_score = [], [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Accumulate so the per-epoch mean below is the real validation loss.
            val_loss += loss.item()
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_score.extend(probs.cpu().numpy())
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        val_loss /= len(val_loader)
        val_acc = correct / total
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
    print(f"Epoch {epoch+1}/{num_epochs} | Train Acc: {train_acc:.2f} | Val Acc: {val_acc:.2f}")

# 10. Save the trained weights (the state_dict only, not the module object).
# NOTE(review): this path is outside the Drive mount at /content/drive, so it is
# presumably on the Colab VM's local disk and lost when the runtime resets — confirm intent.
weights_path = "/content/midvit_model.pt"
torch.save(model.state_dict(), weights_path)

# 11. Per-epoch accuracy (left) and loss (right) for train and validation.
plt.figure(figsize=(12, 5), dpi=300)
_panels = (
    ("Accuracy", train_acc_list, val_acc_list, "Train Acc", "Val Acc"),
    ("Loss", train_loss_list, val_loss_list, "Train Loss", "Val Loss"),
)
for _pos, (_ylabel, _train_curve, _val_curve, _train_label, _val_label) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _pos)
    plt.plot(_train_curve, label=_train_label)
    plt.plot(_val_curve, label=_val_label)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.legend()
plt.suptitle("Accuracy & Loss Function")
plt.tight_layout()
plt.show()

# 12. Confusion matrix of the final epoch's validation predictions
# (rows = actual class, columns = predicted class).
# labels= fixes the matrix at num_classes x num_classes in class-index order.
# Without it, a class absent from both y_true and y_pred is dropped, and the
# remaining rows and columns no longer line up with the class_names tick labels.
cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
plt.figure(figsize=(6, 5), dpi=300)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title("Confusion Matrix")
plt.xlabel("Prediction")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()

# 13. One-vs-rest ROC curve and AUC per class, computed from the final epoch's
# validation softmax scores.
y_true_bin = label_binarize(y_true, classes=list(range(num_classes)))
y_score = np.asarray(y_score)  # list of per-sample probability rows -> 2-D array
fpr, tpr, roc_auc = {}, {}, {}
for cls in range(num_classes):
    cls_fpr, cls_tpr, _thresholds = roc_curve(y_true_bin[:, cls], y_score[:, cls])
    fpr[cls] = cls_fpr
    tpr[cls] = cls_tpr
    roc_auc[cls] = auc(cls_fpr, cls_tpr)

plt.figure(figsize=(8, 6), dpi=300)
for cls in range(num_classes):
    cls_label = f"{class_names[cls]} (AUC = {roc_auc[cls]:.2f})"
    plt.plot(fpr[cls], tpr[cls], label=cls_label)
plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc='lower right')
plt.grid(True)
plt.tight_layout()
plt.show()
