In [1]:
import os
import cv2
import torch
import clip
import numpy as np
from tqdm import tqdm
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from transformers import AutoConfig, AutoModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image


In [2]:
#!git clone https://github.com/Beckschen/TransUNet.git
#!pip install tensorboard tensorboardX ml-collections medpy SimpleITK scipy h5py

In [3]:
import sys
sys.path.append("TransUNet")

In [4]:
from networks.vit_seg_modeling import VisionTransformer

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [6]:
device

device(type='cuda')

In [7]:
#!wget https://download.pytorch.org/models/resnet50-0676ba61.pth

In [8]:
CONFIG = {
    'dataset_path': './cityscapes',  # Replace with actual Cityscapes dataset path
    'transunet_config': {
        'img_size': 512,
        'in_channels': 3,
        'num_classes': 35,
        'hidden_size': 768,
        'num_attention_heads': 12,
        'classifier': 'seg',
        'n_skip': 3,
        'vit_name': 'R50-ViT-B_16',
        'patches': {"grid": (16, 16)},
        'block_units': [3, 4, 6, 3],
        'width_factor': 4,
        'pretrained_path': './pretrained/resnet50.pth'
        },
    'device': 'cuda',
    'batch_size': 4,
    'num_epochs': 10,
    'clip_threshold': 0.82,
    'image_size': (512, 512)
}

In [9]:
class CityscapesData:
    def __init__(self):
        self.item_a = {
            'image': f"{CONFIG['dataset_path']}/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png",
            'mask': f"{CONFIG['dataset_path']}/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelIds.png"
        }
        
        self.item_c = self._load_unlabeled()
        self.item_d = self._load_test_set()
    
    def _load_unlabeled(self):
        paths = []
        base_path = f"{CONFIG['dataset_path']}/leftImg8bit/test/"
        for city in os.listdir(base_path)[:2]:  # Limit for demo
            city_path = os.path.join(base_path, city)
            paths.extend([os.path.join(city_path, f) 
                         for f in os.listdir(city_path)[:50]])
        return paths
    
    def _load_test_set(self):
        paths = []
        base_path = f"{CONFIG['dataset_path']}/leftImg8bit/val"
        for city in os.listdir(base_path)[:1]:  # Limit for demo
            city_path = os.path.join(base_path, city)
            paths.extend([os.path.join(city_path, f) 
                         for f in os.listdir(city_path)[:50]])
        return paths



In [10]:
data = CityscapesData()


In [11]:
class PseudolabelGenerator:
    def __init__(self):
        self.sam = sam_model_registry['vit_h'](
            checkpoint='sam_vit_h_4b8939.pth'
        ).to(CONFIG['device'])
        self.clip_model, self.preprocess = clip.load(
            'ViT-B/32', device=CONFIG['device'])
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
        
    def _clip_similarity(self, image_patch):
        # Convert numpy array to PIL Image
        if isinstance(image_patch, np.ndarray):
            if image_patch.shape[-1] == 3:  # BGR to RGB
                image_patch = cv2.cvtColor(image_patch, cv2.COLOR_BGR2RGB)
            image_patch = Image.fromarray(image_patch.astype('uint8'))
            
        image_tensor = self.preprocess(image_patch).unsqueeze(0).to(CONFIG['device'])
        text = clip.tokenize(["Urban street scene with vehicles, pedestrians, and road infrastructure"]).to(CONFIG['device'])
        
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_tensor)
            text_features = self.clip_model.encode_text(text)
            
        return torch.cosine_similarity(image_features, text_features).item()
    
    def generate(self, image_paths):
        pseudo_labels = []
        
        for path in tqdm(image_paths, desc="Generating pseudo-labels"):
            image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            masks = self.mask_generator.generate(image)
            
            valid_masks = []
            for mask in masks:
                x1, y1, w, h = mask['bbox']
                if w == 0 or h == 0:
                    continue  # Skip invalid masks
                    
                x2 = x1 + w
                y2 = y1 + h
                patch = image[y1:y2, x1:x2]
                
                # Handle empty patches
                if patch.size == 0:
                    continue
                
                similarity = self._clip_similarity(patch)
                if similarity > CONFIG['clip_threshold']:
                    valid_masks.append(mask['segmentation'])
            
            if valid_masks:
                combined_mask = np.zeros(image.shape[:2], dtype=np.uint8)
                for m in valid_masks:
                    combined_mask[m] = 1
                pseudo_labels.append((path, combined_mask))
        
        return pseudo_labels


In [12]:
pl_generator = PseudolabelGenerator()
c_prime = pl_generator.generate(data.item_c[:20])  

Generating pseudo-labels: 100%|█████████████████████████████████████████████████████████| 20/20 [02:11<00:00,  6.56s/it]


In [20]:
pl_generator.generate(data.item_c[:20])  

Generating pseudo-labels:  30%|█████████████████▍                                        | 6/20 [00:39<01:31,  6.52s/it]


KeyboardInterrupt: 

In [13]:
class SegmentationDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        path, mask = self.data[idx]
        image = cv2.imread(path)
        image = cv2.resize(image, CONFIG['image_size'])
        image = torch.tensor(image).permute(2,0,1).float() / 255.0
        
        mask = cv2.resize(mask, CONFIG['image_size'])
        return image, torch.tensor(mask).long()

In [14]:
class TransUNetConfig:
    def __init__(self, **kwargs):
        # Existing parameters
        self.img_size = kwargs.get('img_size', 512)
        self.num_classes = kwargs.get('num_classes', 35)
        self.hidden_size = kwargs.get('hidden_size', 768)
        
        # ResNet backbone parameters
        self.resnet = {
            'block_units': kwargs.get('block_units', [3, 4, 6, 3]),  # For ResNet-50
            'width_factor': kwargs.get('width_factor', 4),
            'pretrained_path': kwargs.get('pretrained_path', 'pretrained/resnet50.pth')
        }
        
        # Transformer parameters
        self.patches = {'grid': (16, 16)}
        self.n_skip = 3
        self.classifier = 'seg'


In [15]:
def train_model(train_data):
    from networks.vit_seg_modeling_resnet_skip import ResNetV2
    
    config = TransUNetConfig(**CONFIG['transunet_config'])
    
    # Initialize ResNet with corrected parameters
    resnet = ResNetV2(
        block_units=config.resnet['block_units'],
        width_factor=config.resnet['width_factor']
    )
    
    # Load pretrained weights if available
    if os.path.exists(config.resnet['pretrained_path']):
        resnet.load_state_dict(torch.load(config.resnet['pretrained_path']))
    
    # Rest of model initialization...


    # Initialize VisionTransformer with ResNet
    model = VisionTransformer(
        config=config,
        img_size=config.img_size,
        num_classes=config.num_classes,
        resnet=resnet
    ).to(CONFIG['device'])

    # Training setup
    dataset = SegmentationDataset(train_data)
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    criterion = torch.nn.CrossEntropyLoss()
    
    # Training loop
    for epoch in range(CONFIG['num_epochs']):
        model.train()
        epoch_loss = 0.0  # Initialize epoch loss
        
        for images, masks in tqdm(loader, desc=f"Epoch {epoch+1}"):
            images = images.to(CONFIG['device'])
            masks = masks.to(CONFIG['device'])

            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        
        print(f"Epoch {epoch+1} Loss: {epoch_loss/len(loader):.4f}")
    
    return model


In [16]:
model = train_model(c_prime)

TypeError: VisionTransformer.__init__() got an unexpected keyword argument 'resnet'

In [21]:
def evaluate_model(model, test_paths):
    model.eval()
    ious = []
    
    for path in tqdm(test_paths, desc="Evaluating"):
        # Load and prepare image
        image = cv2.imread(path)
        image = cv2.resize(image, CONFIG['image_size'])
        image_tensor = torch.tensor(image).permute(2,0,1).float().unsqueeze(0) / 255.0
        image_tensor = image_tensor.to(CONFIG['device'])
        
        with torch.no_grad():
            output = model(image_tensor)  # Remove .logits if unnecessary
            pred_mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
        
        gt_path = path.replace('leftImg8bit', 'gtFine').replace('.png', '_gtFine_labelIds.png')
        gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
        gt_mask = cv2.resize(gt_mask, CONFIG['image_size'])
        
        intersection = np.logical_and(pred_mask == gt_mask, gt_mask > 0)
        union = np.logical_or(pred_mask > 0, gt_mask > 0)
        
        if np.sum(union) > 0:  
            ious.append(np.sum(intersection) / np.sum(union))
    
    return {
        'mean_iou': np.mean(ious),
        'std_iou': np.std(ious),
        'max_iou': np.max(ious),
        'min_iou': np.min(ious)
    }


In [None]:
results = evaluate_model(model, data.item_d[:10])  # Use subset for demo
print("\nEvaluation Results:")
for k, v in results.items():
    print(f"{k:8}: {v:.4f}")

In [None]:
import matplotlib.pyplot as plt

def visualize_sample(image_path, pred_mask, gt_mask):
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image_resized = cv2.resize(image, CONFIG['image_size'])
    
    fig, ax = plt.subplots(1, 3, figsize=(15,5))
    
    ax[0].imshow(image_resized)
    ax[0].set_title('Input Image')
    
    ax[1].imshow(pred_mask)
    ax[1].set_title('Predicted Mask')
    
    ax[2].imshow(gt_mask)
    ax[2].set_title('Ground Truth')
    
    plt.show()


In [None]:
test_image = data.item_d[0]
pred_mask = model(torch.tensor(cv2.imread(test_image)).permute(2,0,1).float().unsqueeze(0).to(CONFIG['device'])/255.0).logits.argmax(1).squeeze().cpu().numpy()
gt_mask = cv2.imread(test_image.replace('leftImg8bit', 'gtFine').replace('.png', '_gtFine_labelIds.png'), cv2.IMREAD_GRAYSCALE)
gt_mask = cv2.resize(gt_mask, CONFIG['image_size'])

In [None]:
visualize_sample(test_image, pred_mask, gt_mask)