In [1]:
import os

# Silence huggingface_hub's warning about symlinks in its download cache.
os.environ.update({'HF_HUB_DISABLE_SYMLINKS_WARNING': '1'})

In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
from transformers import DeiTForImageClassification
from PIL import Image
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Define a custom dataset for property images
class PropertyImageDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.image_paths = [os.path.join(image_folder, img) for img in os.listdir(image_folder) if img.endswith(('.jpg', '.png'))]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path  # Return image and its path for reference

# Define SimCLR model
class SimCLR(nn.Module):
    def __init__(self, base_model, out_dim):
        super(SimCLR, self).__init__()
        self.backbone = nn.Sequential(*list(base_model.children())[:-1])  # Remove the classification layer
        self.projection_head = nn.Sequential(
            nn.Linear(base_model.fc.in_features, 512),
            nn.BatchNorm1d(512),  # Batch Norm
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )

    def forward(self, x):
        h = self.backbone(x).squeeze()  # Backbone representation
        z = self.projection_head(h)  # Projection
        return h, z

In [4]:
# Load trained SimCLR model
def load_simclr_model(model_path='simclr_model1.pth'):
    base_model = models.resnet50(weights='ResNet50_Weights.DEFAULT')  # Load ResNet50
    simclr_model = SimCLR(base_model, out_dim=128)  # SimCLR instance
    state_dict = torch.load(model_path, weights_only=True)
    simclr_model.load_state_dict(state_dict, strict=False)
    simclr_model.eval()  # Evaluation mode
    return simclr_model

# Enhanced SimCLR augmentations
class ImprovedSimCLRTransform:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),  # Added vertical flip
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
            transforms.RandomRotation(15),
            transforms.RandomPerspective(distortion_scale=0.5, p=0.5),  # Added perspective transformation
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, x):
        return self.transform(x)

# Cosine Similarity Contrastive Loss
class CosineContrastiveLoss(nn.Module):
    def __init__(self, margin=0.5):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        cosine_similarity = nn.functional.cosine_similarity(output1, output2)
        loss = torch.mean((1 - label) * torch.pow(cosine_similarity, 2) +
                          label * torch.pow(nn.functional.relu(self.margin - cosine_similarity), 2))
        return loss

In [5]:
# Feature extraction using SimCLR
def extract_features(simclr_model, dataloader):
    features = {}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    simclr_model.to(device)

    total_images = len(dataloader.dataset)
    processed_images = 0

    with torch.no_grad():
        for step, (images, img_paths) in enumerate(dataloader):
            images = images.to(device)
            _, z = simclr_model(images)

            # Save features
            for img_path, feature in zip(img_paths, z.cpu().numpy()):
                features[img_path] = feature
            
            processed_images += len(images)
            print(f"Processed {processed_images}/{total_images} images")

    return features

In [6]:
# Improved data augmentation for DeiT
class PropertyImageDatasetDeiT(Dataset):
    def __init__(self, image_paths, feature_dict, transform=None):
        self.image_paths = image_paths
        self.feature_dict = feature_dict
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        features = self.feature_dict[img_path]
        return image, torch.tensor(features)


In [7]:
# DeiT fine-tuning with gradual unfreezing
def fine_tune_deit(model, dataloader, epochs=10, initial_lr=1e-4, unfreeze_after=5):
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = CosineContrastiveLoss()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # Add a flag to unfreeze once
    unfreezed = False
    
    for epoch in range(epochs):
        print(f"Starting epoch {epoch + 1}/{epochs}")
        model.train()
        epoch_loss = 0
        for step, (images, img_features) in enumerate(dataloader):
            print(f"Processing step {step + 1}/{len(dataloader)}")
            images = images.to(device)
            img_features = img_features.to(device)

            outputs = model(images).logits
            batch_size = outputs.size(0)

            labels = torch.zeros(batch_size).to(device)
            pos_indices = torch.randperm(batch_size)[:batch_size // 2]
            labels[pos_indices] = 1

            loss = criterion(outputs, img_features, labels)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Unfreeze layers once after specified epochs
        if epoch >= unfreeze_after and not unfreezed:
            for param in model.deit.parameters():
                param.requires_grad = True
            unfreezed = True  # Set the flag to True once unfreezing happens

        scheduler.step()
        print(f"Epoch [{epoch + 1}/{epochs}] completed. Average Loss: {epoch_loss / len(dataloader):.4f}")

    return model

In [8]:
# Data preparation and model fine-tuning process
image_folder = r"C:\images"  # Path to your images
transform = ImprovedSimCLRTransform()
dataset = PropertyImageDataset(image_folder=image_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

simclr_model = load_simclr_model()  # Load SimCLR model
simclr_features = extract_features(simclr_model, dataloader)

deit_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

deit_dataset = PropertyImageDatasetDeiT(image_paths=list(simclr_features.keys()), feature_dict=simclr_features, transform=deit_transform)
deit_dataloader = DataLoader(deit_dataset, batch_size=32, shuffle=True)  # Reduce from 32 to 16

Processed 32/8158 images
Processed 64/8158 images
Processed 96/8158 images
Processed 128/8158 images
Processed 160/8158 images
Processed 192/8158 images
Processed 224/8158 images
Processed 256/8158 images
Processed 288/8158 images
Processed 320/8158 images
Processed 352/8158 images
Processed 384/8158 images
Processed 416/8158 images
Processed 448/8158 images
Processed 480/8158 images
Processed 512/8158 images
Processed 544/8158 images
Processed 576/8158 images
Processed 608/8158 images
Processed 640/8158 images
Processed 672/8158 images
Processed 704/8158 images
Processed 736/8158 images
Processed 768/8158 images
Processed 800/8158 images
Processed 832/8158 images
Processed 864/8158 images
Processed 896/8158 images
Processed 928/8158 images
Processed 960/8158 images
Processed 992/8158 images
Processed 1024/8158 images
Processed 1056/8158 images
Processed 1088/8158 images
Processed 1120/8158 images
Processed 1152/8158 images
Processed 1184/8158 images
Processed 1216/8158 images
Processe

In [9]:
# Load Pre-Trained DeiT Model
num_classes = 128
deit_model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224', num_labels=num_classes)

Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Initially freeze all DeiT layers
for param in deit_model.deit.parameters():
    param.requires_grad = False

In [11]:
# Fine-tune the model with gradual unfreezing and enhanced settings
deit_model = fine_tune_deit(deit_model, deit_dataloader, epochs=10, initial_lr=1e-4, unfreeze_after=5)

Starting epoch 1/10
Processing step 1/255


  context_layer = torch.nn.functional.scaled_dot_product_attention(


Processing step 2/255
Processing step 3/255
Processing step 4/255
Processing step 5/255
Processing step 6/255
Processing step 7/255
Processing step 8/255
Processing step 9/255
Processing step 10/255
Processing step 11/255
Processing step 12/255
Processing step 13/255
Processing step 14/255
Processing step 15/255
Processing step 16/255
Processing step 17/255
Processing step 18/255
Processing step 19/255
Processing step 20/255
Processing step 21/255
Processing step 22/255
Processing step 23/255
Processing step 24/255
Processing step 25/255
Processing step 26/255
Processing step 27/255
Processing step 28/255
Processing step 29/255
Processing step 30/255
Processing step 31/255
Processing step 32/255
Processing step 33/255
Processing step 34/255
Processing step 35/255
Processing step 36/255
Processing step 37/255
Processing step 38/255
Processing step 39/255
Processing step 40/255
Processing step 41/255
Processing step 42/255
Processing step 43/255
Processing step 44/255
Processing step 45/

In [12]:
# Save the fine-tuned DeiT model
torch.save(deit_model.state_dict(), 'deit_finetuned_improved.pth')
print("Fine-tuning complete. Model saved as 'deit_finetuned_improved.pth'.")

Fine-tuning complete. Model saved as 'deit_finetuned_improved.pth'.
