In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import timm
import copy
import os

# --- DINO Loss ---
class DINOLoss(nn.Module):
    """Cross-view self-distillation loss with a centered, sharpened teacher.

    Both inputs hold the logits of two views stacked along dim 0 (view 1 rows
    followed by view 2 rows); they are split back with ``chunk(2)``. For every
    (teacher view, student view) pair taken from *different* views, the loss is
    the cross-entropy between the teacher's softmax distribution and the
    student's log-softmax, averaged over samples and over pairs.

    Args:
        out_dim: Size of the logit dimension (width of the running center).
        warmup_teacher_temp: Accepted for signature compatibility; not used.
        teacher_temp: Temperature dividing the centered teacher logits.
        student_temp: Temperature dividing the student logits.
        center_momentum: EMA factor applied to the running center.
    """

    def __init__(self, out_dim, warmup_teacher_temp=0.04, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # Registered as a buffer so it follows .to()/.cuda() and is checkpointed.
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        """Return the scalar loss and update the running center once.

        Args:
            student_output: Student logits, two views stacked along dim 0.
            teacher_output: Teacher logits with the same layout.

        Raises:
            ValueError: If the inputs have different row counts or the row
                count is not a positive even number (two equal views needed).
        """
        n_rows = teacher_output.shape[0]
        if student_output.shape[0] != n_rows or n_rows < 2 or n_rows % 2:
            raise ValueError(
                "student_output and teacher_output must stack two equally "
                f"sized views along dim 0; got {student_output.shape[0]} and "
                f"{n_rows} rows"
            )

        # The teacher only provides targets; no gradient may flow through it.
        teacher_output = teacher_output.detach()
        # Keep working even when the module itself was never moved to the
        # device of the incoming logits.
        center = self.center.to(teacher_output.device)

        student_logp = torch.log_softmax(
            student_output / self.student_temp, dim=-1
        ).chunk(2)
        # Centering + sharpening, then softmax so the targets are probabilities.
        teacher_probs = torch.softmax(
            (teacher_output - center) / self.teacher_temp, dim=-1
        ).chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_probs):
            for v, logp in enumerate(student_logp):
                if v == iq:
                    # Never compare a view with itself.
                    continue
                total_loss = total_loss + torch.sum(-q * logp, dim=-1).mean()
                n_loss_terms += 1
        total_loss = total_loss / n_loss_terms

        # Single EMA update of the center from this batch's teacher logits.
        with torch.no_grad():
            batch_center = teacher_output.mean(dim=0, keepdim=True)
            self.center = (
                center * self.center_momentum
                + (1 - self.center_momentum) * batch_center
            )

        return total_loss

# --- Simple DINO wrapper ---
class DINOHead(nn.Module):
    """Projection head: an MLP, L2 normalisation, then a bias-free linear layer.

    With ``nlayers == 1`` the MLP is a single ``Linear(in_dim, bottleneck_dim)``.
    Otherwise it is one or more ``Linear -> [BatchNorm1d] -> GELU`` stages of
    width ``hidden_dim`` followed by ``Linear(hidden_dim, bottleneck_dim)``.
    """

    def __init__(self, in_dim, out_dim=65536, use_bn=False, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        if nlayers == 1:
            stack = [nn.Linear(in_dim, bottleneck_dim)]
        else:
            stack = []
            # One input stage plus (nlayers - 2) hidden stages, never fewer
            # than one stage in total.
            n_stages = 1 + max(nlayers - 2, 0)
            fan_in = in_dim
            for _ in range(n_stages):
                stack.append(nn.Linear(fan_in, hidden_dim))
                if use_bn:
                    stack.append(nn.BatchNorm1d(hidden_dim))
                stack.append(nn.GELU())
                fan_in = hidden_dim
            stack.append(nn.Linear(hidden_dim, bottleneck_dim))
        self.mlp = nn.Sequential(*stack)
        # Plain linear output layer (no weight normalisation), without bias.
        self.last_layer = nn.Linear(bottleneck_dim, out_dim, bias=False)

    def forward(self, x):
        """Project ``x`` through the MLP, unit-normalise it, and map to ``out_dim``."""
        bottleneck = self.mlp(x)
        unit = nn.functional.normalize(bottleneck, dim=-1)
        return self.last_layer(unit)

# --- Data augmentations (two random views per image) ---
class DINOTransform:
    """Callable that turns one image into two independently augmented views."""

    def __init__(self):
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline = [
            transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, x):
        """Apply the random pipeline twice to ``x`` and return both results."""
        first_view = self.transform(x)
        second_view = self.transform(x)
        return [first_view, second_view]

# --- Dataset and DataLoader ---
from torch.utils.data import Dataset
from PIL import Image
import glob

class UnlabeledImageDataset(Dataset):
    """Dataset over every ``*.*`` file directly inside a folder, without labels.

    Args:
        folder_path: Directory whose files (names containing a dot) are listed.
        transform: Optional callable applied to the RGB PIL image. When it is
            ``None`` the RGB image itself is returned.
    """

    def __init__(self, folder_path, transform=None):
        # Sorted so the index -> file mapping is reproducible across runs/OSes.
        self.paths = sorted(glob.glob(os.path.join(folder_path, "*.*")))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(views, 0)`` for the file at ``idx``; 0 is a dummy label."""
        img_path = self.paths[idx]
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Fall back to the untransformed image instead of leaving `views` unbound.
        views = self.transform(image) if self.transform is not None else image
        return views, 0  # Dummy label

# Build the dataset with the DINO two-view transform and wrap it in a loader
dataset = UnlabeledImageDataset(
    "D:/iceberg1/tl_2018_40001/",
    transform=DINOTransform(),
)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

# --- Model setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# num_classes=0 drops timm's default 1000-way classifier so the backbone
# returns the pooled embed_dim feature that DINOHead(embed_dim) expects;
# with the classifier in place the head would receive 1000-dim logits.
vit = timm.create_model("vit_small_patch16_224", pretrained=False, num_classes=0)

# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (recent PyTorch versions accept weights_only=True).
checkpoint = torch.load("I:/vit_small_patch16_224/dino_deitsmall16_pretrain.pth", map_location="cpu")

# strict=False tolerates naming differences, so report them instead of
# silently starting from partially random weights.
load_result = vit.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Checkpoint/backbone key mismatch - "
        f"missing: {load_result.missing_keys}, "
        f"unexpected: {load_result.unexpected_keys}"
    )

embed_dim = vit.embed_dim
student = nn.Sequential(vit, DINOHead(embed_dim)).to(device)

# The teacher starts as an exact copy and is only updated by EMA, never by
# the optimizer, so its parameters are frozen.
teacher = copy.deepcopy(student).to(device)
for p in teacher.parameters():
    p.requires_grad = False

# --- Optimizer and loss ---
# Only the student is optimised.
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
criterion = DINOLoss(out_dim=65536)
# The loss subtracts its running center from the teacher logits, so the center
# has to live on the same device as the models. Explicit reassignment works
# whether `center` is a plain tensor attribute or a registered buffer.
criterion.center = criterion.center.to(device)

# --- Training loop ---
epochs = 100
momentum = 0.996  # EMA factor for the teacher weights

# Fail early with a clear message instead of a ZeroDivisionError in the
# per-epoch average when no batches are available.
num_batches = len(dataloader)
if num_batches == 0:
    raise RuntimeError("dataloader is empty: no batches to train on")

for epoch in range(epochs):
    student.train()
    total_loss = 0.0

    for views, _ in dataloader:
        views = [v.to(device) for v in views]
        # Each network sees every view; outputs are stacked view after view.
        student_output = torch.cat([student(v) for v in views], dim=0)
        with torch.no_grad():
            teacher_output = torch.cat([teacher(v) for v in views], dim=0)

        loss = criterion(student_output, teacher_output)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # update teacher (EMA): teacher = m * teacher + (1 - m) * student
        with torch.no_grad():
            for param_q, param_k in zip(student.parameters(), teacher.parameters()):
                param_k.mul_(momentum).add_(param_q.detach(), alpha=1 - momentum)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/num_batches:.4f}")

# Saves the whole student (backbone under "0.", projection head under "1.").
torch.save(student.state_dict(), "vit_dino_pretrained.pth")
 

  checkpoint = torch.load("I:/vit_small_patch16_224/dino_deitsmall16_pretrain.pth", map_location="cpu")


In [None]:
vit_backbone = timm.create_model("vit_small_patch16_224", pretrained=False)

# map_location keeps loading independent of the device used for training.
saved_state = torch.load("vit_dino_pretrained.pth", map_location="cpu")

# A checkpoint saved from nn.Sequential(backbone, head) prefixes backbone
# weights with "0." and head weights with "1."; keep only the backbone and
# strip the prefix so the keys match the bare ViT. Checkpoints without that
# prefix are used unchanged.
backbone_prefix = "0."
if any(key.startswith(backbone_prefix) for key in saved_state):
    saved_state = {
        key[len(backbone_prefix):]: value
        for key, value in saved_state.items()
        if key.startswith(backbone_prefix)
    }

# strict=False would otherwise hide a checkpoint that matches nothing.
if not set(saved_state) & set(vit_backbone.state_dict()):
    raise RuntimeError(
        "vit_dino_pretrained.pth shares no parameter names with the ViT backbone"
    )
vit_backbone.load_state_dict(saved_state, strict=False)
vit_backbone.head = nn.Identity()

# Example: classifier for 2 classes
classifier = nn.Sequential(
    vit_backbone,
    nn.Linear(vit_backbone.embed_dim, 2)
)
