<a href="https://colab.research.google.com/github/emredeveloper/AI-with-API/blob/main/Knowleadge_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Hyperparametreler
batch_size = 64
learning_rate = 0.001
num_epochs = 5
temperature = 3.0
alpha = 0.7

# MNIST veri seti yükleme ve dönüştürme
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 46.9MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.53MB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 13.4MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.25MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [3]:
# Basit bir öğretmen ve öğrenci modeli tanımlama
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.fc(x)

class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.fc(x)


# Modelleri başlat
teacher_model = TeacherModel()
student_model = StudentModel()

In [4]:
# Öğretmen modelini eğitme
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

print("Öğretmen modeli eğitiliyor...")
for epoch in range(num_epochs):
    teacher_model.train()
    for images, labels in train_loader:
        outputs = teacher_model(images)
        loss = criterion(outputs, labels)
        teacher_optimizer.zero_grad()
        loss.backward()
        teacher_optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Öğretmen modelini değerlendirme moduna al
teacher_model.eval()

Öğretmen modeli eğitiliyor...
Epoch [1/5], Loss: 0.2294
Epoch [2/5], Loss: 0.0280
Epoch [3/5], Loss: 0.0367
Epoch [4/5], Loss: 0.0553
Epoch [5/5], Loss: 0.0046


TeacherModel(
  (fc): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [5]:
# Knowledge Distillation için özel distillation kaybı
def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    # Yumuşak ve sert hedefler arasında bir kombinasyon kaybı
    soft_loss = nn.KLDivLoss()(nn.functional.log_softmax(student_logits / temperature, dim=1),
                               nn.functional.softmax(teacher_logits / temperature, dim=1)) * (temperature ** 2)
    hard_loss = nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

In [6]:
# Öğrenci modelini eğitme (distillation ile)
student_optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
print("\nÖğrenci modeli distillation ile eğitiliyor...")
for epoch in range(num_epochs):
    student_model.train()
    for images, labels in train_loader:
        with torch.no_grad():
            teacher_logits = teacher_model(images)  # Öğretmen modelin çıktıları
        student_logits = student_model(images)  # Öğrenci modelin çıktıları
        # Distillation kaybını hesapla
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)
        student_optimizer.zero_grad()
        loss.backward()
        student_optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


Öğrenci modeli distillation ile eğitiliyor...




Epoch [1/5], Loss: 0.1248
Epoch [2/5], Loss: 0.0681
Epoch [3/5], Loss: 0.0314
Epoch [4/5], Loss: 0.0438
Epoch [5/5], Loss: 0.0196


In [7]:
# Öğrenci modelinin test doğruluğunu hesapla
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

teacher_accuracy = evaluate_model(teacher_model, test_loader)
student_accuracy = evaluate_model(student_model, test_loader)

print(f"\nÖğretmen Model Test Doğruluğu: {teacher_accuracy:.2f}%")
print(f"Öğrenci Model Test Doğruluğu: {student_accuracy:.2f}%")


Öğretmen Model Test Doğruluğu: 96.88%
Öğrenci Model Test Doğruluğu: 97.12%
