# Initial DINO model

This notebook has the first pass at creating a tree classification model based on DINOv2. It takes as input a tree image, and outputs a class. The available classes are determined by the `classes.json` file for the Laurentian trees dataset.

DINOv2 is run over the image, and then linear probing is done on the output tokens. It's got really bad accuracy and is having difficulty learning.

A number of things could be tried here to improve it:
- Using KNN instead of linear probing, since these classes may not be linearly separable even in token space
- Token-level classification (Martin thinks this is the best approach)
- Adjusting the image pre-processing. Currently they're all scaled to 256*256, which could cause stretching for non-square inputs
- Only using tokens where the tree is located. Since one version of the input uses masked tree images with transparent backgrounds, we could try to just use the non-transparent pixels.

In [1]:
import torch
import torch.nn as nn
import os
from transformers import Dinov2Model, Dinov2PreTrainedModel
from PIL import Image
import json
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create dataloaders

class TreeDataset(Dataset):
    def __init__(self, root_dir, transform, zones=['Z1', 'Z2'], class_map_path='./data/classes.json'):
        with open(class_map_path, 'r') as f:
            self.class_map = {v:int(k) for k, v in json.load(f).items()}
        
        self.root_dir = root_dir
        self.classes = [d for d in os.listdir(root_dir)]
        self.image_files = []
        self.transform = transform
        for c in self.classes:
            for img in os.listdir(os.path.join(root_dir, c)):
                if any(z for z in zones if z in img):
                    self.image_files.append((os.path.join(root_dir, c, img), self.class_map[c]))
        self.toTensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path, c = self.image_files[idx]
        img_np = np.array(Image.open(img_path))[:,:,:3]
        transformed = self.transform(image=img_np)['image']
        return self.toTensor(transformed), torch.tensor(c)

# These are the mean/std I took from the complete tiff of Z1
ADE_MEAN = np.array([51.61087416176021, 70.54108897685563, 43.65073194868197]) / 255
ADE_STD = np.array([66.21302035582556, 82.09431586857384, 54.93294965405881]) / 255

train_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

val_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

train_dataset = TreeDataset('./data/tree_classification', train_transform)
val_dataset = TreeDataset('./data/tree_classification', val_transform, zones=['Z3'])

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)



# Pick this up tomorrow: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb


In [3]:
class Dinov2ForClassification(Dinov2PreTrainedModel):
  def __init__(self, config):
    super().__init__(config)
    self.config = config
    self.dinov2 = Dinov2Model(config)
    self.classifier = nn.Linear(config.hidden_size * 256, config.num_labels)


  def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
    # use frozen features
    outputs = self.dinov2(pixel_values,
                            output_hidden_states=output_hidden_states,
                            output_attentions=output_attentions)
    
    # get the patch embeddings - so we exclude the CLS token
    patch_embeddings = torch.flatten(outputs.last_hidden_state[:,1:,:], start_dim=1)

    # convert to logits and upsample to the size of the pixel values
    logits = self.classifier(patch_embeddings)

    return nn.functional.softmax(logits, dim=1)

model = Dinov2ForClassification.from_pretrained("facebook/dinov2-base", num_labels=29)

# freeze DINOv2 parameters
for name, param in model.named_parameters():
  if name.startswith("dinov2"):
    param.requires_grad = False

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

Some weights of Dinov2ForClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Dinov2ForClassification(
  (dinov2): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            

In [4]:
learning_rate = 1e-6
epochs = 10

optimizer = AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()
model.train()
epoch_train_losses = []
epoch_val_losses = []

for epoch in range(epochs):
  train_loss = 0
  val_loss = 0
  print("Epoch:", epoch)

  for idx, batch in enumerate(tqdm(train_dataloader)):
    data = batch[0].to(device)
    targets = nn.functional.one_hot(batch[1], num_classes=29).float().to(device)
    
    # forward pass
    outputs = model(data)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()

    # zero the parameter gradients
    optimizer.zero_grad()

    train_loss += loss.item() / targets.shape[0]
  
  train_loss /= len(train_dataloader)
  print(f'  Train loss: {train_loss}')

  for idx, batch in enumerate(tqdm(val_dataloader)):
    data = batch[0].to(device)
    targets = nn.functional.one_hot(batch[1], num_classes=29).float().to(device)
    outputs = model(data)
    loss = loss_fn(outputs, targets)
    optimizer.zero_grad()
    val_loss += loss.item() / targets.shape[0]
  
  val_loss /= len(val_dataloader)
  print(f'  Val loss: {val_loss}')
  epoch_train_losses.append(train_loss)
  epoch_val_losses.append(val_loss)


Epoch: 0


100%|██████████| 1143/1143 [01:50<00:00, 10.34it/s]


  Train loss: 0.1987968183538941


100%|██████████| 291/291 [00:26<00:00, 10.95it/s]


  Val loss: 0.20195730768937833
Epoch: 1


100%|██████████| 1143/1143 [01:45<00:00, 10.86it/s]


  Train loss: 0.1961812807714741


100%|██████████| 291/291 [00:25<00:00, 11.45it/s]


  Val loss: 0.19953692594344674
Epoch: 2


100%|██████████| 1143/1143 [01:44<00:00, 10.91it/s]


  Train loss: 0.19311613029590952


100%|██████████| 291/291 [00:25<00:00, 11.47it/s]


  Val loss: 0.19674139375129515
Epoch: 3


100%|██████████| 1143/1143 [01:46<00:00, 10.77it/s]


  Train loss: 0.19124460259447573


100%|██████████| 291/291 [00:25<00:00, 11.40it/s]


  Val loss: 0.19552859188764773
Epoch: 4


100%|██████████| 1143/1143 [01:45<00:00, 10.87it/s]


  Train loss: 0.1904920567078019


100%|██████████| 291/291 [00:25<00:00, 11.38it/s]


  Val loss: 0.19482034156002948
Epoch: 5


100%|██████████| 1143/1143 [01:45<00:00, 10.86it/s]


  Train loss: 0.1897978750583068


100%|██████████| 291/291 [00:25<00:00, 11.63it/s]


  Val loss: 0.19429856982222946
Epoch: 6


100%|██████████| 1143/1143 [01:44<00:00, 10.94it/s]


  Train loss: 0.18930892501603275


100%|██████████| 291/291 [00:25<00:00, 11.46it/s]


  Val loss: 0.1939089653418236
Epoch: 7


100%|██████████| 1143/1143 [01:43<00:00, 11.00it/s]


  Train loss: 0.188947942065546


100%|██████████| 291/291 [00:24<00:00, 11.86it/s]


  Val loss: 0.19351345022109775
Epoch: 8


100%|██████████| 1143/1143 [01:44<00:00, 10.92it/s]


  Train loss: 0.18862033156272306


100%|██████████| 291/291 [00:25<00:00, 11.49it/s]


  Val loss: 0.19318987811144275
Epoch: 9


100%|██████████| 1143/1143 [01:45<00:00, 10.85it/s]


  Train loss: 0.18819623375971486


100%|██████████| 291/291 [00:24<00:00, 11.69it/s]

  Val loss: 0.1928560747928226





In [5]:
print(epoch_train_losses)
print(epoch_val_losses)

[0.1987968183538941, 0.1961812807714741, 0.19311613029590952, 0.19124460259447573, 0.1904920567078019, 0.1897978750583068, 0.18930892501603275, 0.188947942065546, 0.18862033156272306, 0.18819623375971486]
[0.20195730768937833, 0.19953692594344674, 0.19674139375129515, 0.19552859188764773, 0.19482034156002948, 0.19429856982222946, 0.1939089653418236, 0.19351345022109775, 0.19318987811144275, 0.1928560747928226]
