In [1]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class SatDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace("sat.jpg", "mask.png"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask == 255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [None]:
!pip install -q kaggle

In [None]:
!mkdir ~/.kaggle

In [2]:
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_set = SatDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,

    )

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_set = SatDataset(
        image_dir = val_dir,
        mask_dir= val_maskdir,
        transform= val_transform
    )

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [3]:
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, X):
        return self.conv(X)

In [4]:
class UNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs =nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(in_channels=feature*2, out_channels= feature, kernel_size=2, stride=2)) # multiply feature by 2 to account for skip connection
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(in_channels=features[-1], out_channels=features[-1]*2)
        self.final=  nn.Conv2d(features[0], out_channels=out_channels, kernel_size=1)

    def forward(self, X):
        skip_connections = []
        for i, down in enumerate(self.downs):
            X = down(X)
            skip_connections.append(X)
            X = self.pool(X)


        X = self.bottleneck(X)
        skip_connections = list(reversed(skip_connections))


        for i in range(0, len(self.ups), 2):
            X = self.ups[i](X)
            skip_conn = skip_connections[i//2]
            #if X.shape != skip_conn.shape:
                #X = TF.resize(X, size=skip_conn.shape[2:])

            concat_skip = torch.cat((skip_conn,X), dim=1)
            X = self.ups[i+1](concat_skip)

        finalX = self.final(X)
        return self.final(X)

In [5]:
def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()

In [6]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device='cuda')
        targets = targets.float().unsqueeze(1).to(device='cuda')


        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

In [9]:
import torchvision

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])

def save_predictions_as_imgs(
    loader, model, folder="saved_images/", device="cuda"
):
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx}.png"
        )
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()

In [14]:
import torch.optim as optim
from tqdm import tqdm


LEARNING_RATE = 1e-4
NUM_EPOCHS= 5
LOAD_CHECKPOINT= TRUE

model = UNET(in_channels=3, out_channels=1).to('cuda')
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

if LOAD_CHECKPOINT:
        print("loading checkpoints")
        load_checkpoint(torch.load("best_checkpoints.tar"), model)



check_accuracy(val_loader, model, device='cuda')
scaler = torch.cuda.amp.GradScaler()
for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device='cuda')

        # print some examples to a folder
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images", device='cuda'
        )



NameError: name 'FALSE' is not defined

In [None]:
from PIL import Image
from torchvision.transforms import v2


image = Image.open('vit1 (1).jpg')



trans = v2.PILToTensor()
img_tensor = trans(image)

import torch
from torchvision import transforms
import matplotlib.pyplot as plt

%matplotlib inline

transt = transforms.ToTensor()
transp = transforms.ToPILImage()
img_t = transt(Image.open('vit1 (1).jpg'))

#torch.Tensor.unfold(dimension, size, step)
#slices the images into 8*8 size patches
patches = img_t.data.unfold(0, 3, 3).unfold(1, 8, 8).unfold(2, 8, 8)

In [None]:
import torch

# Your original tensor
original_tensor = torch.randn(3, 821, 1753)

# Specify block size
block_size = (256, 256)

# Use unfold to create a view of the tensor as blocks
unfolded_tensor = original_tensor.unfold(1, block_size[0], block_size[0]).unfold(2, block_size[1], block_size[1])

# Get the size of the unfolded tensor
unfolded_size = unfolded_tensor.size()

# Reshape the unfolded tensor to get the final result
result_tensor = unfolded_tensor.reshape(unfolded_size[0], unfolded_size[1], unfolded_size[2], block_size[0], block_size[1])

# Check the size of the result tensor
print(result_tensor.size())
