In [None]:
!pip install torch torchvision transformers pillow matplotlib numpy tqdm datasets pycocotools -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from transformers import BertModel, BertTokenizer
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from datasets import load_dataset

print("✓ Libraries imported")

In [None]:
from pycocotools.coco import COCO
from PIL import Image
import os
import urllib.request
import zipfile
import random

print("Downloading COCO 2017 validation dataset...")
print("This uses REAL COCO images with generated referring expressions")

os.makedirs("coco_data", exist_ok=True)

coco_val_url = "http://images.cocodataset.org/zips/val2017.zip"
coco_ann_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"


def _fetch_and_extract(url, archive_path, extracted_dir, label, download_msg):
    """Download a zip archive (if needed) and unpack it into ``coco_data/``.

    Nothing happens when ``extracted_dir`` already exists. An archive that is
    already present at ``archive_path`` is reused instead of downloaded again.

    Args:
        url: Remote location of the zip archive.
        archive_path: Local path where the finished archive is stored.
        extracted_dir: Directory whose existence marks the archive as unpacked.
        label: Short noun used in the progress messages (e.g. "images").
        download_msg: Message printed right before the download starts.
    """
    if os.path.exists(extracted_dir):
        return

    if not os.path.exists(archive_path):
        print(download_msg)
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated zip that a later run
        # would mistake for a complete archive.
        partial_path = archive_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, archive_path)

    print(f"Extracting {label}...")
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall("coco_data/")
    print(f"✓ {label.capitalize()} extracted")


_fetch_and_extract(
    coco_val_url,
    "coco_data/val2017.zip",
    "coco_data/val2017",
    "images",
    "Downloading validation images (1GB)...",
)

_fetch_and_extract(
    coco_ann_url,
    "coco_data/annotations_trainval2017.zip",
    "coco_data/annotations",
    "annotations",
    "Downloading annotations...",
)

print("Loading COCO dataset...")
coco = COCO("coco_data/annotations/instances_val2017.json")

# Lookup table: COCO category id -> human-readable category name.
cat_names = {cat['id']: cat['name'] for cat in coco.loadCats(coco.getCatIds())}

print(f"✓ Loaded COCO with {len(coco.getImgIds())} images and {len(cat_names)} categories")

class COCOReferringDataset:
    """Referring-expression samples synthesised from COCO instance annotations.

    For every image, an annotation becomes a sample only when it is the sole
    usable instance of its category in that image, so the generated text
    ``"the <category>"`` refers to exactly one object. An annotation is
    "usable" when its area exceeds ``min_area`` and it is not flagged
    ``iscrowd``.

    Args:
        coco: Object exposing the pycocotools ``COCO`` query methods used here
            (``getCatIds``, ``loadCats``, ``getImgIds``, ``getAnnIds``,
            ``loadAnns``, ``loadImgs``, ``annToMask``).
        max_samples: Upper bound on the number of samples collected;
            ``None`` disables the cap.
        min_area: Annotations whose ``area`` is not strictly greater than this
            value are ignored.
        image_dir: Directory containing the image files named by the
            ``file_name`` field of the COCO image records.
    """

    def __init__(self, coco, max_samples=10000, min_area=1000, image_dir="coco_data/val2017"):
        self.coco = coco
        self.min_area = min_area
        self.image_dir = image_dir
        self.cat_names = {cat['id']: cat['name'] for cat in coco.loadCats(coco.getCatIds())}

        self.samples = []

        def is_full():
            return max_samples is not None and len(self.samples) >= max_samples

        for img_id in coco.getImgIds():
            if is_full():
                break

            anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

            # Apply the size / crowd filter once and reuse it for both passes.
            usable = [
                ann for ann in anns
                if ann['area'] > min_area and not ann.get('iscrowd', 0)
            ]

            category_counts = {}
            for ann in usable:
                cat_id = ann['category_id']
                category_counts[cat_id] = category_counts.get(cat_id, 0) + 1

            for ann in usable:
                # Checked before appending so the cap is never exceeded,
                # including the max_samples=0 case.
                if is_full():
                    break

                cat_id = ann['category_id']
                if category_counts[cat_id] != 1:
                    continue  # ambiguous: several instances of this category

                cat_name = self.cat_names[cat_id]
                self.samples.append({
                    'image_id': img_id,
                    'ann_id': ann['id'],
                    'text': f"the {cat_name}",
                    'category': cat_name
                })

        print(f"✓ Created {len(self.samples)} referring expression samples")
        print("  (Only using images with single instance per category)")

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image and the binary instance mask for sample ``idx``.

        Returns:
            dict with keys ``'image'`` (RGB PIL image), ``'mask'`` (result of
            ``coco.annToMask`` for the annotation), ``'text'`` and
            ``'image_id'``.
        """
        sample = self.samples[idx]

        img_info = self.coco.loadImgs(sample['image_id'])[0]
        img_path = f"{self.image_dir}/{img_info['file_name']}"
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released right away instead of lingering until GC.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')

        ann = self.coco.loadAnns(sample['ann_id'])[0]
        mask = self.coco.annToMask(ann)

        return {
            'image': image,
            'mask': mask,
            'text': sample['text'],
            'image_id': sample['image_id']
        }

base_dataset = COCOReferringDataset(coco, max_samples=10000)

# Sequential 80/20 split: the first 80% of the samples train, the rest validate.
total_samples = len(base_dataset)
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size

train_indices = [i for i in range(train_size)]
val_indices = [i for i in range(train_size, total_samples)]

train_base, val_base = (
    torch.utils.data.Subset(base_dataset, split_indices)
    for split_indices in (train_indices, val_indices)
)

print(f"Split: {len(train_base)} train, {len(val_base)} val")

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import BertModel

class VisionEncoder(nn.Module):
    """ResNet backbone with a top-down feature pyramid.

    The four residual stages are projected to ``out_channels`` with 1x1
    lateral convolutions; each finer level additionally receives the
    nearest-neighbour upsampled coarser level.

    Args:
        backbone: Name of the torchvision ResNet constructor to use. Must be
            one of ``SUPPORTED_BACKBONES``.
        pretrained: Forwarded to the torchvision constructor.
        out_channels: Channel width of every pyramid level.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    # The lateral convolutions below assume stage widths of 256/512/1024/2048.
    SUPPORTED_BACKBONES = ('resnet50', 'resnet101')

    def __init__(self, backbone='resnet50', pretrained=True, out_channels=256):
        super().__init__()

        # Reject unknown names explicitly instead of silently building a
        # ResNet-101 for anything that is not 'resnet50'.
        if backbone not in self.SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of {self.SUPPORTED_BACKBONES}"
            )

        # NOTE(review): recent torchvision releases prefer `weights=` over
        # `pretrained=`; kept as-is for compatibility — confirm against the
        # installed version.
        resnet = getattr(models, backbone)(pretrained=pretrained)

        # Stem.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        # Residual stages.
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # 1x1 projections from each stage width to the shared pyramid width.
        self.lateral4 = nn.Conv2d(2048, out_channels, 1)
        self.lateral3 = nn.Conv2d(1024, out_channels, 1)
        self.lateral2 = nn.Conv2d(512, out_channels, 1)
        self.lateral1 = nn.Conv2d(256, out_channels, 1)

    def forward(self, x):
        """Return the pyramid levels ``(p1, p2, p3, p4)``, finest first."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        # Top-down pathway: start at the coarsest level and merge downwards.
        p4 = self.lateral4(c4)
        p3 = self.lateral3(c3) + nn.functional.interpolate(p4, size=c3.shape[-2:], mode='nearest')
        p2 = self.lateral2(c2) + nn.functional.interpolate(p3, size=c2.shape[-2:], mode='nearest')
        p1 = self.lateral1(c1) + nn.functional.interpolate(p2, size=c1.shape[-2:], mode='nearest')

        return p1, p2, p3, p4

class LanguageEncoder(nn.Module):
    """BERT text encoder followed by a linear projection to ``hidden_dim``.

    Args:
        model_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        hidden_dim: Output feature width of the projection.
    """

    def __init__(self, model_name='bert-base-uncased', hidden_dim=256):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Read the encoder width from the loaded config rather than assuming
        # 768, so checkpoints with a different hidden size also work.
        self.projection = nn.Linear(self.bert.config.hidden_size, hidden_dim)

    def forward(self, input_ids, attention_mask):
        """Return per-token features: BERT's last hidden state, projected."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        projected = self.projection(last_hidden)
        return projected

class MultiModalFusion(nn.Module):
    """Joint transformer encoder over visual and language tokens.

    The visual feature map is flattened into one token per spatial location,
    concatenated with the language tokens, run through a transformer encoder,
    and the visual part is reshaped back into a feature map.

    Args:
        hidden_dim: Token width; must equal the channel count of the visual
            features and the last dimension of the language features.
        num_heads: Number of attention heads.
        num_layers: Number of encoder layers.
        dropout: Dropout probability inside the encoder layers.
    """

    def __init__(self, hidden_dim=256, num_heads=8, num_layers=2, dropout=0.1):
        super().__init__()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )

        # Nested-tensor conversion is disabled so that passing a padding mask
        # keeps the dense (batch, tokens, dim) layout sliced in forward().
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False
        )

    def forward(self, visual_features, language_features, language_mask=None):
        """Fuse language into the visual feature map.

        Args:
            visual_features: Tensor of shape (B, C, H, W).
            language_features: Tensor of shape (B, L, C).
            language_mask: Optional (B, L) mask for the language tokens,
                expected to follow the tokenizer ``attention_mask`` convention
                (non-zero = real token, 0 = padding). Padded tokens are
                excluded from attention. ``None`` attends to every token.

        Returns:
            Tensor of shape (B, C, H, W) holding the fused visual tokens.
        """
        B, C, H, W = visual_features.shape
        num_visual = H * W
        visual_flat = visual_features.flatten(2).permute(0, 2, 1)

        combined = torch.cat([visual_flat, language_features], dim=1)

        key_padding_mask = None
        if language_mask is not None:
            # True marks keys to ignore. Visual tokens are never padding;
            # language tokens are ignored wherever the mask is zero.
            visual_keep = torch.zeros(B, num_visual, dtype=torch.bool, device=combined.device)
            language_pad = language_mask.to(combined.device) == 0
            key_padding_mask = torch.cat([visual_keep, language_pad], dim=1)

        fused = self.transformer(combined, src_key_padding_mask=key_padding_mask)

        visual_fused = fused[:, :num_visual, :].permute(0, 2, 1).reshape(B, C, H, W)

        return visual_fused

class SegmentationDecoder(nn.Module):
    """Five-stage upsampling head that turns a feature map into mask logits.

    Every stage applies a 3x3 convolution that halves the channel count,
    batch normalisation, ReLU and a 2x bilinear upsample, giving 32x spatial
    upsampling overall. A final 1x1 convolution maps to ``num_classes``.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of output channels.
    """

    # Stage name paired with the divisor applied to ``in_channels`` at its input.
    _STAGES = (('up5', 1), ('up4', 2), ('up3', 4), ('up2', 8), ('up1', 16))

    def __init__(self, in_channels=256, num_classes=1):
        super().__init__()

        def make_stage(divisor):
            # Conv -> BN -> ReLU -> 2x upsample, halving the channel count.
            width_in = in_channels // divisor
            width_out = in_channels // (divisor * 2)
            return nn.Sequential(
                nn.Conv2d(width_in, width_out, 3, padding=1),
                nn.BatchNorm2d(width_out),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            )

        # Registered under the same attribute names and in the same order as
        # a hand-written up5..up1 definition.
        for stage_name, divisor in self._STAGES:
            setattr(self, stage_name, make_stage(divisor))

        self.final = nn.Conv2d(in_channels // 32, num_classes, 1)

    def forward(self, x):
        """Run all upsampling stages in order and return the raw logits."""
        for stage_name, _ in self._STAGES:
            x = getattr(self, stage_name)(x)
        return self.final(x)

class ReferringSegmentationModel(nn.Module):
    """Vision encoder + language encoder + fusion + decoder for referring segmentation.

    Args:
        backbone: Backbone name forwarded to ``VisionEncoder``.
        pretrained: Whether the vision backbone loads pretrained weights.
        hidden_dim: Shared feature width of all sub-modules.
    """

    def __init__(self, backbone='resnet50', pretrained=True, hidden_dim=256):
        super().__init__()

        self.vision_encoder = VisionEncoder(backbone, pretrained, hidden_dim)
        self.language_encoder = LanguageEncoder('bert-base-uncased', hidden_dim)
        self.fusion = MultiModalFusion(hidden_dim, num_heads=8, num_layers=2)
        self.decoder = SegmentationDecoder(hidden_dim, num_classes=1)

    def forward(self, images, input_ids, attention_mask):
        """Predict a mask for the object described by the text.

        Args:
            images: Image batch; the last two dimensions are height and width.
            input_ids: Token ids for the referring expression.
            attention_mask: Attention mask belonging to ``input_ids``.

        Returns:
            Tuple ``(mask_probs, fused)``: sigmoid probabilities with the
            channel dimension removed, at the spatial size of ``images``, and
            the fused feature map produced by the fusion module.
        """
        # Only the coarsest pyramid level feeds the fusion module.
        _, _, _, p4 = self.vision_encoder(images)

        lang_features = self.language_encoder(input_ids, attention_mask)

        fused = self.fusion(p4, lang_features, attention_mask)

        mask_logits = self.decoder(fused)

        # The decoder upsamples by a fixed factor of 32, which reproduces the
        # input size only when height and width are multiples of 32. Resize
        # the logits otherwise so predictions always align with the targets.
        if mask_logits.shape[-2:] != images.shape[-2:]:
            mask_logits = nn.functional.interpolate(
                mask_logits,
                size=images.shape[-2:],
                mode='bilinear',
                align_corners=False
            )

        return torch.sigmoid(mask_logits.squeeze(1)), fused

print("✓ Model architecture defined")

In [None]:
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator to avoid a
            division by zero when both inputs are empty.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - dice`` for the flattened prediction and target."""
        flat_pred = pred.reshape(-1)
        flat_target = target.reshape(-1)

        overlap = torch.sum(flat_pred * flat_target)
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_target.sum() + self.smooth

        return 1 - numerator / denominator

class CombinedLoss(nn.Module):
    """Weighted sum of binary cross-entropy and Dice loss.

    Args:
        bce_weight: Multiplier for the binary cross-entropy term.
        dice_weight: Multiplier for the Dice term.
    """

    def __init__(self, bce_weight=1.0, dice_weight=1.0):
        super().__init__()
        self.bce_weight, self.dice_weight = bce_weight, dice_weight
        self.dice_loss = DiceLoss()

    def forward(self, pred, target):
        """Return ``bce_weight * BCE + dice_weight * Dice``.

        ``pred`` is passed to ``F.binary_cross_entropy``, so it must already
        contain probabilities rather than logits.
        """
        bce_term = self.bce_weight * F.binary_cross_entropy(pred, target)
        dice_term = self.dice_weight * self.dice_loss(pred, target)
        return bce_term + dice_term

def compute_iou(pred, target, threshold=0.5):
    """Per-sample intersection-over-union of thresholded predictions.

    Args:
        pred: Prediction scores; reduced over dimensions 1 and 2.
        target: Ground-truth masks with the same layout as ``pred``.
        threshold: Scores strictly above this value count as foreground.

    Returns:
        One IoU value per sample. A small epsilon on both sides of the
        division makes an empty prediction against an empty target score 1.
    """
    eps = 1e-6
    spatial_dims = (1, 2)

    hard_pred = pred.gt(threshold).float()
    hard_target = target.float()

    overlap = torch.sum(hard_pred * hard_target, dim=spatial_dims)
    combined_area = hard_pred.sum(dim=spatial_dims) + hard_target.sum(dim=spatial_dims)

    return (overlap + eps) / (combined_area - overlap + eps)

def compute_dice(pred, target, threshold=0.5):
    """Per-sample Dice coefficient of thresholded predictions.

    Args:
        pred: Prediction scores; reduced over dimensions 1 and 2.
        target: Ground-truth masks with the same layout as ``pred``.
        threshold: Scores strictly above this value count as foreground.

    Returns:
        One Dice value per sample, stabilised with a small epsilon.
    """
    eps = 1e-6
    spatial_dims = (1, 2)

    hard_pred = pred.gt(threshold).float()
    hard_target = target.float()

    overlap = torch.sum(hard_pred * hard_target, dim=spatial_dims)
    combined_area = hard_pred.sum(dim=spatial_dims) + hard_target.sum(dim=spatial_dims)

    return (2. * overlap + eps) / (combined_area + eps)

print("✓ Loss functions defined")

In [None]:
import torchvision.transforms as transforms
from transformers import BertTokenizer

class RefCOCODataset(Dataset):
    """Wrap a referring-expression dataset to yield model-ready tensors.

    Each item of ``base_dataset`` must be a dict with the keys ``'image'``,
    ``'mask'`` (a 2-D NumPy array) and ``'text'``.

    Args:
        base_dataset: Indexable dataset of raw samples.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'attention_mask'`` tensors with a leading batch dimension.
        image_size: Side length to which images and masks are resized.
        max_length: Fixed token length (padded / truncated).
    """

    def __init__(self, base_dataset, tokenizer, image_size=320, max_length=20):
        self.base_dataset = base_dataset
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.max_length = max_length

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return the transformed image, resized mask and token tensors.

        Returns:
            dict with ``'image'`` (output of the transform pipeline),
            ``'mask'`` (float tensor of shape ``(image_size, image_size)``),
            ``'input_ids'`` and ``'attention_mask'`` (batch dimension removed).
        """
        sample = self.base_dataset[idx]

        image = self.transform(sample['image'])

        mask = torch.from_numpy(sample['mask']).float()
        # Nearest-neighbour resampling keeps the mask values unchanged.
        resized = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0),
            size=(self.image_size, self.image_size),
            mode='nearest'
        )
        # Drop exactly the two dimensions added above; a bare squeeze() would
        # also collapse spatial dimensions of size 1.
        mask = resized[0, 0]

        tokens = self.tokenizer(
            sample['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'image': image,
            'mask': mask,
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0)
        }

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Wrap both raw splits with the same tokenizer and default preprocessing.
train_dataset, val_dataset = (
    RefCOCODataset(raw_split, tokenizer)
    for raw_split in (train_base, val_base)
)

print(f"✓ Datasets ready: {len(train_dataset)} train, {len(val_dataset)} val")

In [None]:
import matplotlib.pyplot as plt

sample = train_dataset[0]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
image_ax, mask_ax = axes

# Invert the Normalize step of the dataset transform (same mean / std values)
# so the image can be displayed with its original colours.
channel_std = np.array([0.229, 0.224, 0.225])
channel_mean = np.array([0.485, 0.456, 0.406])
img_np = sample['image'].permute(1, 2, 0).numpy()
img_np = np.clip(img_np * channel_std + channel_mean, 0, 1)

image_ax.imshow(img_np)
image_ax.set_title(f"Text: '{train_base[0]['text']}'")
image_ax.axis('off')

mask_ax.imshow(sample['mask'].numpy(), cmap='gray')
mask_ax.set_title('Ground Truth Mask')
mask_ax.axis('off')

plt.tight_layout()
plt.show()

for label, key in (("Image", 'image'), ("Mask", 'mask'), ("Text tokens", 'input_ids')):
    print(f"{label} shape: {sample[key].shape}")

In [None]:
# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

model = ReferringSegmentationModel()
model = model.to(device)
param_count = sum(param.numel() for param in model.parameters())
print(f"Parameters: {param_count:,}")

criterion = CombinedLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# One cosine cycle across 15 scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

batch_size = 8

# Per-epoch caps that keep the demo fast: at most this many batches are used
# for training and for validation in every epoch.
max_train_batches = 100
max_val_batches = 50

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}\n")

num_epochs = 15
best_iou = 0.0
train_losses = []
val_ious = []

print("Training...\n")

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    epoch_loss = 0.0
    train_batches_seen = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch_idx, batch in enumerate(pbar):
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        optimizer.zero_grad()
        pred_masks, _ = model(images, input_ids, attention_mask)
        loss = criterion(pred_masks, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_batches_seen += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Stop after exactly max_train_batches batches.
        if train_batches_seen >= max_train_batches:
            break

    # Average over the batches that actually ran, so the reported loss stays
    # correct whether or not the cap was reached.
    avg_loss = epoch_loss / max(train_batches_seen, 1)
    train_losses.append(avg_loss)

    # ---- validation ----
    model.eval()
    val_iou_list = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pred_masks, _ = model(images, input_ids, attention_mask)
            iou = compute_iou(pred_masks, masks)
            val_iou_list.extend(iou.cpu().numpy())
            if batch_idx + 1 >= max_val_batches:
                break

    # An empty validation set reports 0 instead of NaN.
    avg_iou = float(np.mean(val_iou_list)) * 100 if val_iou_list else 0.0
    val_ious.append(avg_iou)
    scheduler.step()

    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, IoU={avg_iou:.2f}%")

    if avg_iou > best_iou:
        best_iou = avg_iou
        print(f"  ✓ Best: {best_iou:.2f}%")

print(f"\n✓ Done! Best IoU: {best_iou:.2f}%")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# (axis, series, extra line style, y-label, title) for each panel.
curve_panels = (
    (axes[0], train_losses, {}, 'Loss', 'Training Loss'),
    (axes[1], val_ious, {'color': 'green'}, 'IoU (%)', 'Validation IoU'),
)

for panel_ax, series, line_style, y_label, panel_title in curve_panels:
    panel_ax.plot(series, marker='o', linewidth=2, **line_style)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

model.eval()
# Sampling without replacement fails when more items are requested than
# exist, so never ask for more rows than there are validation samples.
num_vis = min(10, len(val_dataset))
indices = np.random.choice(len(val_dataset), num_vis, replace=False)

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_vis, 4, figsize=(16, 4*num_vis), squeeze=False)

# Statistics used to undo the dataset's Normalize transform, shaped to
# broadcast over a (3, H, W) image.
norm_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
norm_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

with torch.no_grad():
    for i, idx in enumerate(indices):
        base_idx = int(idx)
        sample = val_dataset[base_idx]
        image = sample['image'].unsqueeze(0).to(device)
        input_ids = sample['input_ids'].unsqueeze(0).to(device)
        attention_mask = sample['attention_mask'].unsqueeze(0).to(device)
        gt_mask = sample['mask'].cpu().numpy()

        pred_mask, _ = model(image, input_ids, attention_mask)
        pred_mask_np = pred_mask[0].cpu().numpy()

        img = image[0].cpu().numpy()
        img = np.clip((img * norm_std + norm_mean).transpose(1, 2, 0), 0, 1)

        # Fetch the caption from the sample metadata when the wrapped dataset
        # is a Subset of a dataset exposing `.samples`; this avoids decoding
        # the image and mask from disk a second time just to read the text.
        base = val_dataset.base_dataset
        if hasattr(base, 'dataset') and hasattr(base.dataset, 'samples'):
            actual_idx = base.indices[base_idx]
            text = base.dataset.samples[actual_idx]['text']
        else:
            text = base[base_idx]['text']

        pred_binary = pred_mask_np > 0.5
        gt_binary = gt_mask > 0.5
        iou = (pred_binary & gt_binary).sum() / ((pred_binary | gt_binary).sum() + 1e-8)

        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f'"{text}"', fontsize=9)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_mask, cmap='gray')
        axes[i, 1].set_title('Ground Truth', fontsize=9)
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_mask_np, cmap='gray')
        axes[i, 2].set_title(f'Prediction (IoU: {iou:.3f})', fontsize=9)
        axes[i, 2].axis('off')

        # Blend predicted foreground pixels 50/50 with pure red.
        overlay = img.copy()
        overlay[pred_binary] = overlay[pred_binary] * 0.5 + np.array([1, 0, 0]) * 0.5
        axes[i, 3].imshow(overlay)
        axes[i, 3].set_title('Overlay', fontsize=9)
        axes[i, 3].axis('off')

plt.tight_layout()
plt.show()

print(f"✓ Visualized {num_vis} predictions")

In [None]:
model.eval()
all_ious = []

# Evaluate the validation loader, stopping once batch index 200 has been processed.
with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(val_loader, desc='Evaluation')):
        images, masks, input_ids, attention_mask = (
            batch[field].to(device)
            for field in ('image', 'mask', 'input_ids', 'attention_mask')
        )

        pred_masks, _ = model(images, input_ids, attention_mask)
        iou = compute_iou(pred_masks, masks)
        all_ious += list(iou.cpu().numpy())

        if batch_idx >= 200:
            break

all_ious = np.array(all_ious)

separator = "=" * 60
print("\n" + separator)
print("COCO REFERRING EXPRESSION SEGMENTATION RESULTS")
print(separator)
print(f"Mean IoU:        {all_ious.mean() * 100:.2f}%")
print(f"Median IoU:      {np.median(all_ious) * 100:.2f}%")
print("\nPrecision @ IoU:")
# Share of samples whose IoU is strictly above each threshold.
for thresh in (0.5, 0.6, 0.7, 0.8, 0.9):
    prec = (all_ious > thresh).mean() * 100
    print(f"  P@{thresh:.1f}: {prec:.2f}%")
print(separator)
print(f"Samples: {len(all_ious)}")
print("Dataset: COCO 2017 with single-instance referring expressions")

In [None]:
# Closing summary, printed one entry per line.
summary_lines = [
    "\n✅ COMPLETE!",
    f"Best IoU: {best_iou:.2f}%",
    "\nUsed REAL COCO 2017 validation dataset",
    "Dataset: COCO val2017 with single-instance object filtering",
    "Method: ResNet-50 + BERT + Cross-Modal Transformers",
    "\nNote: Results improved using single-instance filtering",
    "For better results: train 40+ epochs on full dataset with spatial expressions",
]
print(*summary_lines, sep="\n")