In [1]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim
from torchvision.models import mobilenet_v2

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")


In [3]:
from torch.utils.data import Dataset
from PIL import Image
import os

class UnlabeledDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.image_paths = [os.path.join(root, f) for f in os.listdir(root) if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image  # No label

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=15),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Load datasets
synthetic_dataset = ImageFolder(root="data/synthetic/cifar10", transform=transform_test)
unlabeled_dataset = UnlabeledDataset("data/real/unlabelled", transform=transform_test)
test_dataset = ImageFolder(root="data/real/animal_data", transform=transform_test)

batch_size = 32
synthetic_loader = DataLoader(synthetic_dataset, batch_size=batch_size, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [5]:
class FeatureExtractor(nn.Module):
    def __init__(self, num_classes):
        super(FeatureExtractor, self).__init__()
        # Load pre-trained MobileNetV2
        mobilenet = mobilenet_v2(pretrained=True)
        
        # Freeze all parameters
        for param in mobilenet.parameters():
            param.requires_grad = False
        
        # Use all layers except the last classifier
        self.features = mobilenet.features
        
        # Add a simple classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, num_classes)  # MobileNetV2's last conv layer has 1280 channels
        )

    def forward(self, x):
        x = self.features(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [6]:
num_classes = 3
def load_model(model, path, device):
    model.load_state_dict(torch.load(path, map_location=device))
    print(f"Model loaded from {path}")
    return model

# Load the saved model
model = FeatureExtractor(num_classes=3).to(device)
model = load_model(model, 'model.pth', device)
model.eval()



Model loaded from model.pth


  model.load_state_dict(torch.load(path, map_location=device))


FeatureExtractor(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96,

In [7]:
def generate_pseudo_labels(model, dataloader, threshold=0.8):
    model.eval()  # Set to evaluation mode
    pseudo_data = []
    
    with torch.no_grad():
        for images in dataloader:  # Unlabeled dataset only has images
            images = images.to(device)
            outputs = model(images)  # ✅ No need for 'model.model()'
            probabilities = torch.softmax(outputs, dim=1)
            confidence, pseudo_labels = torch.max(probabilities, dim=1)

            # Store only high-confidence predictions
            for i in range(len(images)):
                if confidence[i] > threshold:
                    pseudo_data.append((images[i].cpu(), pseudo_labels[i].cpu()))  
    
    return pseudo_data



from torch.utils.data import Dataset
import torch

class PseudoLabeledDataset(Dataset):
    def __init__(self, pseudo_data, transform=None):
        self.pseudo_data = pseudo_data
        self.transform = transform

    def __len__(self):
        return len(self.pseudo_data)

    def __getitem__(self, idx):
        image, label = self.pseudo_data[idx]

        # Convert label properly
        label = label.clone().detach().long()

        if self.transform:
            image = self.transform(image)

        return image, label




In [8]:
pseudo_data = generate_pseudo_labels(model, unlabeled_loader, threshold=0.8)
pseudo_dataset = PseudoLabeledDataset(pseudo_data, transform=None)  # Avoid double transforms

# Combine datasets
combined_dataset = ConcatDataset([synthetic_dataset, pseudo_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

# Retrain Model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
num_epochs = 10

print("\nRetraining model with domain adaptation...")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in combined_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        total_loss += loss.item()

    accuracy = 100 * correct / total
    avg_loss = total_loss / len(combined_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

# Evaluate on test dataset
model.eval()
correct = 0
total = 0
test_loss = 0.0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        test_loss += loss.item()

test_accuracy = 100 * correct / total
test_loss /= len(test_loader)

print(f"\nFinal Test Accuracy: {test_accuracy:.2f}% | Test Loss: {test_loss:.4f}")

# Save the improved model
torch.save(model.state_dict(), "domain_adapted_model.pth")
print("\nImproved model saved as 'domain_adapted_model.pth'.")



Retraining model with domain adaptation...


TypeError: expected Tensor as element 1 in argument 0, but got int

In [None]:
for images, labels in combined_loader:
    print(f"Images type: {type(images)}, Labels type: {type(labels)}")
    print(f"Labels shape: {labels.shape}")  # Should be a tensor
    break


Images type: <class 'torch.Tensor'>, Labels type: <class 'torch.Tensor'>
Labels shape: torch.Size([32])


  label = torch.tensor(label, dtype=torch.long)


In [None]:
for img, lbl in pseudo_dataset:
    print(f"Image type: {type(img)}, Label type: {type(lbl)}, Label dtype: {lbl.dtype}")
    break


Image type: <class 'torch.Tensor'>, Label type: <class 'torch.Tensor'>, Label dtype: torch.int64


  label = torch.tensor(label, dtype=torch.long)


In [None]:
for images, labels in combined_loader:
    print(f"Images type: {type(images)}, Labels type: {type(labels)}")
    print(f"Labels shape: {labels.shape}, Labels dtype: {labels.dtype}")
    break


  label = torch.tensor(label, dtype=torch.long)


Images type: <class 'torch.Tensor'>, Labels type: <class 'torch.Tensor'>
Labels shape: torch.Size([32]), Labels dtype: torch.int64
