In [1]:
# Install dependencies
!pip install -q kaggle transformers torchvision

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Upload your Kaggle API token (kaggle.json)
from pathlib import Path

from google.colab import files

# files.upload() opens a file picker and returns {filename: file contents as bytes}.
uploaded = files.upload()  # Upload kaggle.json
if not uploaded:
    raise RuntimeError("No file was uploaded; please select your kaggle.json API token.")

# Install the token from the uploaded bytes instead of copying a file that must
# be named exactly "kaggle.json" in the working directory: the upload may have
# been saved under another name (e.g. "kaggle (1).json" on a repeated upload).
token_bytes = uploaded.get("kaggle.json", next(iter(uploaded.values())))

kaggle_dir = Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
token_path = kaggle_dir / "kaggle.json"
token_path.write_bytes(token_bytes)
# Owner read/write only, equivalent to `chmod 600`.
token_path.chmod(0o600)

Saving kaggle.json to kaggle.json


In [3]:
# Download dataset
!kaggle datasets download -d humansintheloop/teeth-segmentation-on-dental-x-ray-images
!unzip -q teeth-segmentation-on-dental-x-ray-images.zip -d teeth_data

Dataset URL: https://www.kaggle.com/datasets/humansintheloop/teeth-segmentation-on-dental-x-ray-images
License(s): CC0-1.0


In [25]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from google.colab import files
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm
from transformers import ViTConfig, ViTModel

In [26]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [27]:
# Dataset class
class TeethSegDataset(torch.utils.data.Dataset):
    """Paired X-ray images and binary masks read from an unzipped dataset folder.

    Images are read from ``<root_dir>/Teeth Segmentation PNG/d2/img`` and the
    mask for each image is looked up in
    ``<root_dir>/Teeth Segmentation PNG/d2/masks_human`` under the same file
    stem with a ``.png`` extension.

    Args:
        root_dir: Directory the dataset archive was extracted into.
        image_size: ``(height, width)`` that every image and mask is resized to.
    """

    # Only directory entries with one of these extensions are treated as
    # images; anything else (hidden files, checkpoint folders, ...) is skipped.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, root_dir, image_size=(224, 224)):
        self.img_dir = os.path.join(root_dir, 'Teeth Segmentation PNG', 'd2', 'img')
        self.mask_dir = os.path.join(root_dir, 'Teeth Segmentation PNG', 'd2', 'masks_human')
        # Sorted so that indices map to the same files on every run.
        self.img_list = sorted(
            name for name in os.listdir(self.img_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of images found in the image directory."""
        return len(self.img_list)

    def _mask_path(self, img_name):
        """Return the mask path for ``img_name``: same stem, ``.png`` extension."""
        # splitext only touches the final extension, whatever its case, and
        # leaves any ".jpg" that happens to appear earlier in the name alone.
        stem, _ = os.path.splitext(img_name)
        return os.path.join(self.mask_dir, stem + ".png")

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, mask)``: the resized RGB image as a float tensor, and the
            resized single-channel mask binarised to 0.0 / 1.0.
        """
        img_name = self.img_list[idx]

        # convert() returns a fully loaded copy, so the files can be closed
        # as soon as the conversion is done.
        with Image.open(os.path.join(self.img_dir, img_name)) as img_file:
            image = img_file.convert("RGB")
        with Image.open(self._mask_path(img_name)) as mask_file:
            mask = mask_file.convert("L")

        image = self.transform(image)
        mask = self.transform(mask)
        # Resizing interpolates mask values; threshold back to a hard 0/1 mask.
        mask = (mask > 0.5).float()

        return image, mask

In [28]:
# ViT-UNet model
class ViTUNet(nn.Module):
    """Binary segmentation network: ViT encoder plus a small convolutional decoder.

    The encoder is a 6-layer ViT built directly from a ``ViTConfig`` (no
    pretrained weights are loaded). Its patch tokens are folded back into a
    2-D feature map, passed through a stack of convolutions and transposed
    convolutions (4x spatial upsampling in total) ending in a sigmoid, and
    finally resized bilinearly to the resolution of the input.

    Args:
        image_size: Side length, in pixels, of the square inputs the encoder
            is configured for.
        patch_size: Side length, in pixels, of each square ViT patch.
    """

    def __init__(self, image_size=224, patch_size=16):
        super().__init__()
        config = ViTConfig(image_size=image_size, patch_size=patch_size, num_channels=3,
                           hidden_size=768, num_attention_heads=12, num_hidden_layers=6)
        self.encoder = ViTModel(config)

        # Input channels (768) must match the encoder's hidden_size above.
        self.decoder = nn.Sequential(
            nn.Conv2d(768, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Predict a per-pixel probability map.

        Args:
            x: Batch of images, ``(batch, 3, height, width)``.

        Returns:
            Tensor of shape ``(batch, 1, height, width)`` with values produced
            by the decoder's final sigmoid.

        Raises:
            ValueError: If the encoder's patch tokens do not form a square grid.
        """
        out_size = tuple(x.shape[-2:])

        tokens = self.encoder(pixel_values=x).last_hidden_state
        b, n, c = tokens.shape

        # The first token is dropped (the class token); the remaining n - 1
        # patch tokens are laid out as a square grid.
        side = int(round((n - 1) ** 0.5))
        if side * side != n - 1:
            raise ValueError(
                f"Cannot arrange {n - 1} patch tokens into a square feature map"
            )
        feature_map = tokens[:, 1:, :].permute(0, 2, 1).reshape(b, c, side, side)

        probs = self.decoder(feature_map)

        # Upsample to match the ground truth mask size (same as the input)
        return F.interpolate(probs, size=out_size, mode='bilinear', align_corners=False)



In [29]:
# Load dataset
root_path = "teeth_data"  # after unzipping
dataset = TeethSegDataset(root_path)
# Fail early with a clear message if the image directory turned out to be empty.
assert len(dataset) > 0, f"No images found under {dataset.img_dir!r}"

# A dedicated, seeded generator makes the shuffle order repeatable across runs
# without touching the global random state.
shuffle_generator = torch.Generator().manual_seed(42)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, generator=shuffle_generator)

In [30]:
# Initialize model, loss, optimizer
# Build the network first, then move its parameters to the selected device
# (nn.Module.to works in place and returns the module itself).
model = ViTUNet()
model = model.to(device)

# Binary cross-entropy on the model's sigmoid outputs.
criterion = nn.BCELoss()

# Adam over every model parameter.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [31]:
# Training loop
epochs = 5
model.train()
for epoch in range(epochs):
    # Sum of the per-batch losses seen in this epoch.
    running_loss = 0.0
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

    for batch_images, batch_masks in progress:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        # Clear gradients left over from the previous step before computing new ones.
        optimizer.zero_grad()

        predictions = model(batch_images)
        loss = criterion(predictions, batch_masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean loss per batch for the epoch.
    print(f"Epoch {epoch+1} Loss: {running_loss / len(dataloader):.4f}")

Epoch 1/5: 100%|██████████| 75/75 [12:43<00:00, 10.18s/it]


Epoch 1 Loss: 0.5152


Epoch 2/5: 100%|██████████| 75/75 [12:32<00:00, 10.03s/it]


Epoch 2 Loss: 0.4269


Epoch 3/5: 100%|██████████| 75/75 [12:21<00:00,  9.88s/it]


Epoch 3 Loss: 0.4064


Epoch 4/5: 100%|██████████| 75/75 [12:24<00:00,  9.93s/it]


Epoch 4 Loss: 0.3914


Epoch 5/5: 100%|██████████| 75/75 [12:31<00:00, 10.02s/it]

Epoch 5 Loss: 0.3592





In [32]:
# Save model
# Only the weights (state dict) are written, not the full module object.
model_path = "vit_teeth_segmentation.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)
print(f"Model saved as {model_path}")

Model saved as vit_teeth_segmentation.pth


In [33]:
# Download to local machine
# Reuse model_path from the save cell so the saved and downloaded file names
# cannot drift apart.
files.download(model_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>