<a href="https://colab.research.google.com/github/cbradley264/UGTR_COD/blob/main/COD_570.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The goal of this notebook is to reimplement the paper "Uncertainty-Guided Transformer Reasoning for Camouflaged Object Detection" (UGTR).
Note: To run it, add the TestDataset folder to your Google Drive, then point dataset_path at the CHAMELEON dataset folder and gt_path at its ground-truth (GT) subfolder.

In [128]:
!pip freeze > requirements.txt

In [None]:
import torch
import os
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score

In [None]:
# Report whether PyTorch can see a CUDA device in this runtime.
print(f'Can I use GPU now? -- {torch.cuda.is_available()}')

In [None]:
# Mount Google Drive at /content/drive (the `google.colab` module is only
# importable inside a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from torchvision import datasets, transforms

# Root folder handed to ImageFolder, and the folder holding the ground-truth masks.
dataset_path = '/content/drive/My Drive/ECE570/TestDataset/CHAMELEON'
gt_path = '/content/drive/My Drive/ECE570/TestDataset/CHAMELEON/GT'

# Spatial size every image is brought to before entering the network.
IMAGE_SIZE = (128, 128)

# Training-time augmentation. `scale` is the fraction of the source image AREA
# that the random crop may cover, so values above 1.0 can never be satisfied;
# the upper bound is therefore capped at 1.0.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

# Deterministic preprocessing for evaluation: random crops/flips would move the
# image content while the ground-truth masks loaded from disk stay untouched.
eval_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# NOTE(review): train and test read the same folder, so there is no held-out
# split here — confirm this is intended.
test_dataset = datasets.ImageFolder(root=dataset_path, transform=eval_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
def build_position_encoding(dim, height, width):
    """Build a normalized-coordinate positional map of shape (1, 2 * (dim // 2), height, width).

    The first ``dim // 2`` channels all hold the x coordinate (0 at the left
    column, 1 at the right column); the remaining ``dim // 2`` channels all
    hold the y coordinate (0 at the top row, 1 at the bottom row).

    Args:
        dim: Requested channel count; an odd value yields ``dim - 1`` channels.
        height: Number of rows in the map.
        width: Number of columns in the map.

    Returns:
        A CPU float tensor with a leading batch axis of size 1.
    """
    half = dim // 2
    rows = torch.linspace(0, 1, steps=height).view(1, height, 1)
    cols = torch.linspace(0, 1, steps=width).view(1, 1, width)
    # Broadcast each 1-D coordinate ramp over the full grid and over `half` channels.
    y_channels = rows.expand(half, height, width)
    x_channels = cols.expand(half, height, width)
    return torch.cat((x_channels, y_channels), dim=0)[None]

In [None]:
class UGTR(nn.Module):
    """ResNet-50 features, a transformer encoder and a stochastic prediction head.

    Pipeline implemented by ``forward``:
      1. torchvision ResNet-50 stages (``layer0`` .. ``layer4``).
      2. 1x1 convolution from 2048 to ``hidden_dim`` (256) channels.
      3. Additive positional map from ``build_position_encoding``.
      4. Three-layer ``TransformerEncoder`` over the flattened spatial grid.
      5. Two 1x1 heads producing a one-channel mean and log-variance.
      6. A reparameterized sample, lifted back to 256 channels, then a 1x1
         prediction convolution followed by a sigmoid.

    Args:
        classes: Number of output channels of the prediction map.
        zoom_factor: When different from 1, outputs are bilinearly resized
            to ``int((size - 1) / 8 * zoom_factor + 1)`` along each spatial axis.
        pretrained: Forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, classes=1, zoom_factor=8, pretrained=True):
        super().__init__()
        self.zoom_factor = zoom_factor
        self.classes = classes

        # Backbone: the stem is grouped as layer0, the residual stages are reused as-is.
        backbone = models.resnet50(pretrained=pretrained)
        self.layer0 = nn.Sequential(
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool
        )
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.hidden_dim = 256
        # 2048 is the channel count that layer4 of ResNet-50 is projected from.
        self.input_proj = nn.Conv2d(2048, self.hidden_dim, kernel_size=1)
        # Lifts the one-channel sampled map back to 256 channels for the prediction head.
        self.channel_adjust = nn.Conv2d(1, 256, kernel_size=1)

        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=self.hidden_dim, nhead=8, dim_feedforward=512, dropout=0.1
            ),
            num_layers=3,
        )

        self.pred = nn.Conv2d(self.hidden_dim, classes, kernel_size=1)
        self.mean_conv = nn.Conv2d(self.hidden_dim, 1, kernel_size=1)
        self.logvar_conv = nn.Conv2d(self.hidden_dim, 1, kernel_size=1)

    def reparameterize(self, mu, logvar, k=1):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal.

        ``k`` is accepted but not used by this implementation.
        """
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return noise * std + mu

    def reduce_channels(self, x):
        """Project ``x`` to ``hidden_dim`` channels with a lazily created 1x1 convolution.

        The layer is built on first use from ``x.shape[1]`` and placed on
        ``x.device``; later calls reuse it. ``forward`` does not call this method.
        """
        if hasattr(self, 'reduce_conv'):
            return self.reduce_conv(x)
        self.reduce_conv = nn.Conv2d(x.shape[1], self.hidden_dim, kernel_size=1).to(x.device)
        return self.reduce_conv(x)

    def forward(self, x):
        """Return ``(pred, uncertainty)`` for an image batch ``x`` of shape (B, C, H, W).

        ``pred`` is the sigmoid of the prediction head (values in [0, 1]) and
        ``uncertainty`` is ``exp(logvar)``. A fresh noise sample is drawn on
        every call, in both train and eval mode.
        """
        in_h, in_w = x.size(2), x.size(3)
        out_h = int((in_h - 1) / 8 * self.zoom_factor + 1)
        out_w = int((in_w - 1) / 8 * self.zoom_factor + 1)

        feats = x
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        feats = self.input_proj(feats)

        batch_size, channels, height, width = feats.shape
        positions = build_position_encoding(self.hidden_dim, height, width).to(feats.device)
        feats = feats + positions

        # The encoder is built with the default sequence-first layout: (H*W, B, C).
        tokens = feats.flatten(2).permute(2, 0, 1)
        tokens = self.transformer_encoder(tokens)
        feats = tokens.permute(1, 2, 0).view(batch_size, channels, height, width)

        mean = self.mean_conv(feats)
        logvar = self.logvar_conv(feats)
        sample = self.reparameterize(mean, logvar, k=1)
        uncertainty = torch.exp(logvar)

        pred = torch.sigmoid(self.pred(self.channel_adjust(sample)))

        if self.zoom_factor != 1:
            target_size = (out_h, out_w)
            pred = F.interpolate(pred, size=target_size, mode='bilinear', align_corners=True)
            uncertainty = F.interpolate(uncertainty, size=target_size, mode='bilinear', align_corners=True)

        return pred, uncertainty


In [None]:
def poly_lr_scheduler(optimizer, base_lr, iter, max_iter, power):
    """Set every param group's learning rate to the polynomial-decay value.

    The rate is ``base_lr * (1 - iter / max_iter) ** power``, with the progress
    ratio clamped to ``[0, 1]``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an ``'lr'`` key.
        base_lr: Learning rate at iteration 0.
        iter: Current iteration index.
        max_iter: Total number of iterations; must be non-zero.
        power: Exponent of the decay curve.

    Returns:
        The learning rate that was written into the optimizer.
    """
    # Without the clamp, iter > max_iter makes the base negative, and a negative
    # float raised to a fractional power is a complex number in Python.
    progress = min(max(iter / max_iter, 0.0), 1.0)
    lr = base_lr * (1 - progress) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [None]:
def train(model, dataloader, criterion, optimizer, device, num_epochs=10, checkpoint_path='model_checkpoint.pth', base_lr=1e-4, power=0.9):
    """Train ``model`` with a polynomially decayed learning rate, checkpointing every epoch.

    The optimized loss is ``criterion(outputs, targets) + 0.1 * mean(uncertainty)``.

    Args:
        model: Module whose forward returns ``(outputs, uncertainty)``.
        dataloader: Iterable of ``(images, targets)`` batches supporting ``len()``.
        criterion: Loss applied to ``(outputs, targets)``.
        optimizer: Optimizer over the model parameters.
        device: Device the model and batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        checkpoint_path: File the state dict is written to after every epoch.
        base_lr: Starting learning rate of the poly schedule.
        power: Exponent of the poly schedule.
    """
    model.train()
    model.to(device)

    num_batches = len(dataloader)
    max_iter = num_batches * num_epochs
    global_iter = 0
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_uncertainty_loss = 0.0

        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
            # Count batches explicitly: the progress bar's own counter may lag
            # behind the loop, which would skew the running averages.
            for batch_idx, (images, targets) in enumerate(pbar, start=1):
                images, targets = images.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs, uncertainty = model(images)
                # Each per-image target value is broadcast into a constant map
                # matching the output's spatial size.
                # NOTE(review): with an ImageFolder loader these values are class
                # indices, not segmentation masks — confirm this is intended.
                targets = targets.view(-1, 1, 1, 1).repeat(1, 1, outputs.shape[2], outputs.shape[3]).float()
                main_loss = criterion(outputs, targets)
                uncertainty_loss = torch.mean(uncertainty)
                loss = main_loss + 0.1 * uncertainty_loss

                loss.backward()
                optimizer.step()

                # Advance the global step so the schedule actually decays; the
                # new rate applies to the next optimizer step.
                global_iter += 1
                poly_lr_scheduler(optimizer, base_lr, global_iter, max_iter, power)

                epoch_loss += main_loss.item()
                epoch_uncertainty_loss += uncertainty_loss.item()
                pbar.set_postfix({
                    'Loss': epoch_loss / batch_idx,
                    'Uncertainty Loss': epoch_uncertainty_loss / batch_idx
                })

        torch.save(model.state_dict(), checkpoint_path)
        # Avoid dividing by zero when the loader yields no batches.
        denom = max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {epoch_loss / denom:.4f}, Uncertainty Loss: {epoch_uncertainty_loss / denom:.4f}")

    print("Training completed.")

In [None]:
def test(model, dataloader, criterion, device):
    """Compute the average criterion loss and average uncertainty over ``dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: Module whose forward returns ``(outputs, uncertainty)``.
        dataloader: Iterable of ``(images, targets)`` batches supporting ``len()``.
        criterion: Loss applied to ``(outputs, targets)``.
        device: Device the model and batches are moved to.

    Returns:
        ``(avg_loss, avg_uncertainty_loss)``, each averaged over the number of
        batches (``0.0`` for an empty loader).
    """
    model.eval()
    model.to(device)
    total_loss = 0.0
    uncertainty_loss = 0.0
    num_batches = len(dataloader)

    with torch.no_grad():
        with tqdm(dataloader, desc="Testing", unit="batch") as pbar:
            # Count batches explicitly: the progress bar's own counter may lag
            # behind the loop, which would skew the running averages.
            for batch_idx, (images, targets) in enumerate(pbar, start=1):
                images, targets = images.to(device), targets.to(device)

                outputs, uncertainty = model(images)
                # Broadcast each per-image target value to the output's spatial size.
                targets = targets.view(-1, 1, 1, 1).repeat(1, 1, outputs.shape[2], outputs.shape[3]).float()
                main_loss = criterion(outputs, targets)
                uncertainty_loss_batch = torch.mean(uncertainty)

                total_loss += main_loss.item()
                uncertainty_loss += uncertainty_loss_batch.item()

                pbar.set_postfix({
                    'Loss': total_loss / batch_idx,
                    'Uncertainty Loss': uncertainty_loss / batch_idx
                })

    # Avoid dividing by zero when the loader yields no batches.
    denom = max(num_batches, 1)
    avg_loss = total_loss / denom
    avg_uncertainty_loss = uncertainty_loss / denom
    print(f"Test Loss: {avg_loss:.4f}, Uncertainty Loss: {avg_uncertainty_loss:.4f}")
    return avg_loss, avg_uncertainty_loss

In [None]:
def evaluate_model(model, dataloader, device, ground_truth_folder, dataset):
    """Print mean MAE, E-, S- and F-scores of the model's thresholded predictions.

    For every sample the first model output is thresholded at 0.5 and compared
    with the mask stored at ``{ground_truth_folder}/animal-{76 - label}.png``,
    where ``label`` is the target value yielded by ``dataloader``.

    Args:
        model: Module whose forward returns ``(predictions, uncertainty)``.
        dataloader: Iterable of ``(images, targets)`` batches.
        device: Device used for inference and metric inputs.
        ground_truth_folder: Directory containing the ``animal-<n>.png`` masks.
        dataset: Currently unused.
    """
    model.eval()
    model.to(device)

    mae_scores = []
    e_scores = []
    s_scores = []
    f_scores = []

    with torch.no_grad():
        with tqdm(dataloader, desc="Evaluating", unit="batch") as pbar:
            for images, targets in pbar:
                images = images.to(device)

                predictions, _ = model(images)
                predicted_masks = (predictions.squeeze(1) > 0.5).float()

                for mask, animal_num in zip(predicted_masks, targets):
                    # NOTE(review): the file name is derived from the loader's
                    # label as 76 - label; verify this mapping against the
                    # dataset's folder layout.
                    gt_path = f"{ground_truth_folder}/animal-{76 - int(animal_num)}.png"
                    ground_truth = load_ground_truth(gt_path).to(device)
                    # Nearest-neighbour resizing keeps the mask strictly binary;
                    # an interpolating resize would introduce fractional values.
                    ground_truth = F.interpolate(
                        ground_truth.unsqueeze(0), size=mask.shape[-2:], mode="nearest"
                    ).squeeze(0)

                    mae_scores.append(torch.mean(torch.abs(mask - ground_truth)).item())
                    e_scores.append(calculate_e_measure(mask, ground_truth))
                    s_scores.append(calculate_s_measure(mask, ground_truth))
                    f_scores.append(calculate_f_measure(mask, ground_truth))

    if not mae_scores:
        # np.mean of an empty list would warn and print nan for every metric.
        print("No samples were evaluated.")
        return

    mean_mae = np.mean(mae_scores)
    mean_e = np.mean(e_scores)
    mean_s = np.mean(s_scores)
    mean_f = np.mean(f_scores)

    print(f"Mean MAE: {mean_mae:.4f}")
    print(f"Mean E-measure (Eφ): {mean_e:.4f}")
    print(f"Mean S-measure (Sα): {mean_s:.4f}")
    print(f"Mean Weighted F-measure (Fwβ): {mean_f:.4f}")

def load_ground_truth(filepath):
    """Load a ground-truth mask from ``filepath`` as a binary float tensor.

    The image is converted to 8-bit grayscale, turned into a tensor with
    ``torchvision.transforms.ToTensor`` and thresholded at 0.5.

    Args:
        filepath: Path of the mask image.

    Returns:
        A float tensor containing only 0.0 and 1.0.
    """
    from PIL import Image
    import torchvision.transforms as transforms

    # The context manager releases the file handle as soon as the pixel data
    # has been copied into the converted image.
    with Image.open(filepath) as image:
        grayscale = image.convert("L")
    tensor_image = transforms.ToTensor()(grayscale)
    return (tensor_image > 0.5).float()

def calculate_e_measure(pred, gt):
    """Return the score reported as E-measure (Eφ) for one prediction/mask pair.

    Computes the mean over all pixels of
    ``(pred * mean(gt) + gt * mean(pred)) / (mean(gt) + mean(pred) + 1e-8)``.

    NOTE(review): this looks like a simplified proxy rather than the standard
    enhanced-alignment measure — confirm before comparing with published numbers.

    Args:
        pred: Prediction tensor.
        gt: Ground-truth tensor, broadcastable against ``pred``.
    """
    pred_np, gt_np = pred.cpu().numpy(), gt.cpu().numpy()
    gt_mean = np.mean(gt_np)
    pred_mean = np.mean(pred_np)
    numerator = pred_np * gt_mean + gt_np * pred_mean
    denominator = gt_mean + pred_mean + 1e-8
    return np.mean(numerator / denominator)

def calculate_s_measure(pred, gt):
    """Return the score reported as S-measure (Sα) for one prediction/mask pair.

    The result is ``0.5 * overlap(pred, gt) + 0.25 * (fg + bg)``, where
    ``overlap(a, b) = 2 * sum(a * b) / (sum(a) + sum(b) + 1e-8)`` and ``fg`` /
    ``bg`` are the same overlap restricted to pixels with ``gt >= 0.5`` /
    ``gt < 0.5``.

    NOTE(review): this looks like a simplified proxy rather than the standard
    structure measure — confirm before comparing with published numbers.

    Args:
        pred: Prediction tensor.
        gt: Ground-truth tensor, broadcastable against ``pred``.
    """
    pred_np = pred.cpu().numpy()
    gt_np = gt.cpu().numpy()

    def overlap(a, b):
        # Dice-style ratio with a small epsilon guarding the empty case.
        return 2 * np.sum(a * b) / (np.sum(a) + np.sum(b) + 1e-8)

    foreground = gt_np >= 0.5
    background = gt_np < 0.5

    obj_score = overlap(pred_np, gt_np)
    region_fg = overlap(pred_np * foreground, gt_np * foreground)
    region_bg = overlap(pred_np * background, gt_np * background)

    alpha = 0.5
    return alpha * obj_score + (1 - alpha) * 0.5 * (region_fg + region_bg)

def calculate_f_measure(pred, gt, beta_square=0.3):
    """Return the F-beta score of a prediction against a ground-truth mask.

    Both inputs are flattened and binarized at 0.5, so probability maps are
    accepted as well as hard masks. Precision and recall are counted directly
    and combined as
    ``(1 + beta_square) * P * R / (beta_square * P + R + 1e-8)``.

    NOTE(review): despite being reported as "weighted" F-measure, no per-pixel
    weighting is applied here — confirm whether that is intended.

    Args:
        pred: Prediction tensor.
        gt: Ground-truth tensor with the same number of elements as ``pred``.
        beta_square: Weight of precision relative to recall.

    Returns:
        The score as a float; ``0.0`` when precision and recall are both zero.
    """
    pred_fg = pred.cpu().numpy().flatten() > 0.5
    gt_fg = gt.cpu().numpy().flatten() > 0.5

    true_pos = int(np.sum(pred_fg & gt_fg))
    n_pred = int(np.sum(pred_fg))
    n_gt = int(np.sum(gt_fg))

    # An empty prediction or empty mask gives a zero ratio instead of 0/0.
    precision = true_pos / n_pred if n_pred else 0.0
    recall = true_pos / n_gt if n_gt else 0.0

    if precision + recall == 0:
        return 0.0

    return float((1 + beta_square) * (precision * recall) / (beta_square * precision + recall + 1e-8))

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UGTR()
# UGTR.forward already passes its prediction through a sigmoid, so the loss
# must take probabilities; BCEWithLogitsLoss would apply a second sigmoid.
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-4)

# Rebuild the loaders with the batch size used for this run.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Uncomment to train and to report the test loss. With both lines disabled the
# evaluation below runs on the model exactly as constructed above.
# train(model, train_loader, criterion, optimizer, device, num_epochs=20, checkpoint_path='model_checkpoint.pth', base_lr=1e-4, power=0.9)
# test_loss, test_uncertainty_loss = test(model, test_loader, criterion, device)

evaluate_model(model, test_loader, device, gt_path, test_dataset)
