In [23]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader

import csv
import time
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable after "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

grid_size = 11      # grid_size * grid_size == 121, the number of depth values predicted per image
num_classes = 2     # two classes: "object" and "not object"
# Exclusive upper bound of the sample index: the paths are built with
# range(1, data_length), i.e. samples 1..874.
# NOTE(review): the previous comment said indices start at 0 - confirm whether
# a sample 0 exists and should be included.
data_length = 875
train_length = 700  # the first `train_length` paths are the training split, the rest are validation

epochs = 100
batch_size = 32

### Model Definition

In [16]:
class MinecraftFoundationModel(nn.Module):
    """ResNet-50 backbone with a depth-regression head and an object-prediction head.

    The backbone is a pretrained torchvision ResNet-50 with its last two child
    modules removed, so it emits a 2048-channel spatial feature map.  The depth
    head flattens that map and therefore requires it to be 7x7 (e.g. 224x224
    input images); it regresses ``grid_size * grid_size`` values per image.

    The object-prediction head is built but is only evaluated when explicitly
    requested, because the training loop in this notebook optimises depth only.
    """

    def __init__(self):
        super().__init__()

        # `weights=` replaces the deprecated `pretrained=True` flag.  DEFAULT is
        # the same weight set whose `transforms()` this notebook uses for
        # preprocessing, so the backbone and the input pipeline now agree.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        # Drop the last two children (global pooling and the classifier) to keep
        # the spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Object Prediction Head
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
            # NOTE(review): a single int upsamples to a square map of side
            # grid_size * grid_size (121 x 121).  If one prediction per grid
            # cell was intended this should be size=(grid_size, grid_size) -
            # confirm before training this head.
            nn.Upsample(size=grid_size * grid_size, mode='bilinear', align_corners=True),
            nn.Softmax(dim=1)
        )

        # Depth Estimation Head
        self.depth_estimation_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048 * 7 * 7, 512),  # requires a 7x7 backbone feature map
            nn.ReLU(),
            nn.Linear(512, grid_size * grid_size)  # one value per grid cell (121 for grid_size=11)
        )

    def forward(self, x, return_segmentation=False):
        """Predict depth values (and optionally the object map) for a batch.

        Args:
            x: Batch of images accepted by the ResNet-50 backbone.
            return_segmentation: When True, also evaluate the object-prediction
                head and return ``(depth, segmentation)``.  Defaults to False,
                which returns only ``depth`` and skips the unused head instead
                of computing and discarding its output.

        Returns:
            ``depth`` of shape ``(batch, grid_size * grid_size)``, or the tuple
            ``(depth, segmentation)`` when ``return_segmentation`` is True.
        """
        features = self.backbone(x)
        depth = self.depth_estimation_head(features)

        if return_segmentation:
            return depth, self.segmentation_head(features)
        return depth

### Dataset Definition

In [19]:
class MinecraftDepthDataset(Dataset):
    """Pairs each image file with the distance column of its annotation CSV.

    Args:
        image_paths: Sequence of image file paths.
        depth_paths: Sequence of annotation CSV paths, aligned with
            ``image_paths`` by index.
        transform: Optional callable applied to the loaded RGB PIL image.
        train: Stored as ``self.trainset``; not otherwise used by this class.
    """

    def __init__(self, image_paths, depth_paths, transform=None, train=True):
        self.image_paths = image_paths
        self.depth_paths = depth_paths
        self.transform = transform
        self.trainset = train

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, depth)`` where ``image`` is the (optionally transformed)
            RGB image and ``depth`` is a 1-D float32 tensor holding the second
            column of every data row in the annotation CSV.

        Raises:
            ValueError: If the annotation file is completely empty.
        """
        # The context manager closes the file handle promptly; convert('RGB')
        # loads the pixels and guarantees 3 channels even for RGBA/palette PNGs.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        depth_path = self.depth_paths[idx]
        # newline='' is what the csv module documents for files passed to csv.reader.
        with open(depth_path, mode='r', newline='') as file:
            csv_reader = csv.reader(file)
            # Skip the header row.  An explicit error is raised for an empty
            # file because a bare StopIteration escaping from here could end
            # an enclosing iteration loop without any message.
            if next(csv_reader, None) is None:
                raise ValueError(f'annotation file is empty: {depth_path}')
            # Blank lines come back from csv.reader as empty lists; skip them.
            distances = [float(row[1]) for row in csv_reader if row]

        if self.transform is not None:
            image = self.transform(image)

        depth = torch.tensor(distances, dtype=torch.float32)

        return image, depth

# Use the preprocessing pipeline that ships with the pretrained ResNet-50 weights.
weights = models.ResNet50_Weights.DEFAULT
transform = weights.transforms()

# Example paths, replace with your actual paths.
# Sample indices run from 1 to data_length - 1 and are zero-padded to five digits.
sample_ids = range(1, data_length)
image_paths = [f'data/Unannotated/image_{i:05d}.png' for i in sample_ids]
depth_paths = [f'data/Annotations/annotation_{i:05d}.csv' for i in sample_ids]

# Page-locked host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just triggers a warning.
pin_memory = torch.cuda.is_available()

# The first `train_length` samples are used for training ...
train_dataset = MinecraftDepthDataset(
    image_paths[:train_length], depth_paths[:train_length], transform=transform, train=True
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)

# ... and the remainder for validation (flagged with train=False, kept in file order).
val_dataset = MinecraftDepthDataset(
    image_paths[train_length:], depth_paths[train_length:], transform=transform, train=False
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

### Training

In [26]:
def _train_one_epoch(model, loader, loss_fn, optimizers, device, desc):
    """Run one optimisation pass over `loader`; return the mean per-sample loss."""
    model.train()
    running_loss = 0.0

    with tqdm(total=len(loader), desc=desc, unit='batch') as pbar:
        start_time = time.time()
        for i, (images, depth_targets) in enumerate(loader):
            # non_blocking only has an effect for pinned host memory copied to a GPU.
            images = images.to(device, non_blocking=True)
            depth_targets = depth_targets.to(device, non_blocking=True)

            for optimizer in optimizers:
                optimizer.zero_grad()

            depth_pred = model(images)
            loss = loss_fn(depth_pred, depth_targets)
            loss.backward()

            for optimizer in optimizers:
                optimizer.step()

            # Weight by batch size so a smaller final batch does not skew the epoch mean.
            running_loss += loss.item() * images.size(0)

            # i is 0-based: i + 1 batches have completed at this point.
            elapsed = time.time() - start_time
            rate = (i + 1) / elapsed if elapsed > 0 else 0.0
            pbar.set_postfix({'loss': loss.item(), 'iter/s': '{:.2f}'.format(rate)})
            pbar.update(1)

    return running_loss / len(loader.dataset)


def _evaluate(model, loader, loss_fn, device):
    """Return the mean per-sample loss over `loader` without tracking gradients."""
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for images, depth_targets in loader:
            images = images.to(device, non_blocking=True)
            depth_targets = depth_targets.to(device, non_blocking=True)

            depth_pred = model(images)
            loss = loss_fn(depth_pred, depth_targets)
            running_loss += loss.item() * images.size(0)

    return running_loss / len(loader.dataset)


model = MinecraftFoundationModel().to(device)

depth_loss_fn = nn.MSELoss()
# Separate optimisers give the pretrained backbone a 10x smaller learning rate
# than the freshly initialised depth head.
optimizer_backbone = torch.optim.Adam(model.backbone.parameters(), lr=1e-4)
optimizer_depth = torch.optim.Adam(model.depth_estimation_head.parameters(), lr=1e-3)

# Per-epoch loss history, consumed by the plotting cell.
train_losses = []
val_losses = []

for epoch in range(epochs):
    # Training step
    train_loss = _train_one_epoch(
        model,
        train_loader,
        depth_loss_fn,
        (optimizer_depth, optimizer_backbone),
        device,
        desc=f'Epoch {epoch+1}/{epochs}',
    )
    train_losses.append(train_loss)

    # Validation step
    val_loss = _evaluate(model, val_loader, depth_loss_fn, device)
    val_losses.append(val_loss)

    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

Epoch 1/100: 100%|██████████| 22/22 [00:25<00:00,  1.18s/batch, loss=30.3, iter/s=0.81]


Epoch 1/100, Train Loss: 86.7518, Val Loss: 26.1852


Epoch 2/100:  23%|██▎       | 5/22 [00:06<00:22,  1.32s/batch, loss=17, iter/s=0.70]  


KeyboardInterrupt: 

In [None]:
# Plot the per-epoch loss curves recorded by the training cell.
fig, ax = plt.subplots(figsize=(10, 5))

# Number epochs from 1 so the x-axis matches the "Epoch N/..." training log.
# The two histories get separate ranges because an interrupted run can leave
# them with different lengths.  Markers keep a one-epoch history visible,
# since a line through a single point draws nothing.
ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', markersize=3, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, marker='o', markersize=3, label='Validation Loss')

ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()