<a href="https://colab.research.google.com/github/eisbetterthanpi/vision/blob/main/meta_pseudo_labels_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# @title data
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
labeled_size = int(0.1 * len(train_dataset))  # Adjust the labeled percentage as needed
unlabeled_size = len(train_dataset) - labeled_size
labeled_data, unlabeled_data = random_split(train_dataset, [labeled_size, unlabeled_size])

batch_size=64
labeled_loader = DataLoader(labeled_data, batch_size=batch_size, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)


In [27]:
# @title model
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TeacherCNN(nn.Module):
    def __init__(self):
        super(TeacherCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.lin = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv(x)
        x=x.flatten(1) # x = x.view(x.size(0), -1)
        x = self.lin(x)
        return x

class StudentCNN(nn.Module):
    def __init__(self):
        super(StudentCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc1 = nn.Linear(32 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv(x)
        x=x.flatten(1) # x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

teacher_model = TeacherCNN().to(device)
student_model = StudentCNN().to(device)



In [21]:
# @title train/test
# from tqdm import tqdm

def train(teacher_model, student_model, labeled_loader, unlabeled_loader, criterion, teacher_optimizer, student_optimizer, alpha, device):
    teacher_model.train()
    student_model.train()
    total_loss = 0

    # for labeled_batch, unlabeled_batch in tqdm(zip(labeled_loader, unlabeled_loader), total=min(len(labeled_loader), len(unlabeled_loader))):
    for labeled_batch, unlabeled_batch in zip(labeled_loader, unlabeled_loader):
        # Train teacher model on labeled data
        inputs, labels = labeled_batch
        inputs, labels = inputs.to(device), labels.to(device)
        teacher_optimizer.zero_grad()
        teacher_outputs = teacher_model(inputs)
        teacher_loss = criterion(teacher_outputs, labels)
        teacher_loss.backward()
        teacher_optimizer.step()

        # Use teacher model to generate pseudo-labels for unlabeled data
        unlabeled_inputs, _ = unlabeled_batch
        unlabeled_inputs = unlabeled_inputs.to(device)
        with torch.no_grad():
            # pseudo_labels = torch.argmax(teacher_model(unlabeled_inputs), dim=1)
            pseudo_labels = teacher_model(unlabeled_inputs) # soft_pseudo_labels

        # Train student model on both labeled and unlabeled data
        student_optimizer.zero_grad()
        labeled_outputs = student_model(inputs)
        pseudo_outputs = student_model(unlabeled_inputs)
        labeled_loss = criterion(labeled_outputs, labels)
        pseudo_loss = criterion(pseudo_outputs, pseudo_labels)
        total_loss = labeled_loss + alpha * pseudo_loss
        total_loss.backward()
        student_optimizer.step()
    return total_loss.item()

def test(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader)):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_samples += labels.size(0)
    accuracy = correct_predictions / total_samples
    average_loss = total_loss / len(dataloader)
    return accuracy, average_loss


In [26]:
# @title run
import torch.optim as optim

# def train_and_evaluate(teacher_model, student_model, labeled_loader, unlabeled_loader, test_loader, num_epochs, alpha, device):
criterion = nn.CrossEntropyLoss()
teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.0001)
student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)

for epoch in range(20):
    teacher_loss = train(teacher_model, student_model, labeled_loader, unlabeled_loader, criterion, teacher_optimizer, student_optimizer, alpha, device)
    test_accuracy, test_loss = test(student_model, test_loader, criterion, device)
    print(f'{epoch + 1} Teacher Loss: {teacher_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Test Loss: {test_loss:.4f}')



100%|██████████| 157/157 [00:04<00:00, 34.65it/s]


1 Teacher Loss: -59.7059, Test Accuracy: 0.5758, Test Loss: 9.0380


100%|██████████| 157/157 [00:04<00:00, 35.16it/s]


2 Teacher Loss: -3087.1494, Test Accuracy: 0.0974, Test Loss: 2150.6514


100%|██████████| 157/157 [00:04<00:00, 33.73it/s]


3 Teacher Loss: -19515.1758, Test Accuracy: 0.0974, Test Loss: 9577.8751


KeyboardInterrupt: 