In [1]:
import torch
# Sanity check: prints whether PyTorch can see a CUDA device (expect True on a GPU runtime).
print(str(torch.cuda.is_available()))


True


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os

# SRCNN Model
class SRCNN(nn.Module):
    """Super-Resolution CNN: three convolutions refining an already-upsampled image.

    Uses the 9-1-5 kernel layout with 64 and 32 feature maps from the original
    SRCNN paper. The paddings (4 for the 9x9 conv, 2 for the 5x5 conv) keep the
    spatial size unchanged, so the output can be compared pixel-for-pixel with
    a high-resolution target of the same size.

    Args:
        num_channels: number of image channels in both input and output.
            Defaults to 1 (single-channel / greyscale); use 3 for RGB.
    """

    def __init__(self, num_channels: int = 1):
        super().__init__()
        # Attribute names form the state_dict keys ('layer1.weight', ...);
        # keep them stable so previously saved checkpoints still load.
        self.layer1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)  # patch extraction
        self.layer2 = nn.Conv2d(64, 32, kernel_size=1)                       # non-linear mapping
        self.layer3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)  # reconstruction
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map a (batch, num_channels, H, W) tensor to a tensor of the same shape."""
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        # No activation on the last layer: the output is a regression of pixel values.
        return self.layer3(x)

# Dataset class for X-rays
class XrayDataset(Dataset):
    """Pairs each high-resolution image in a folder with a synthetic low-resolution copy.

    The low-resolution input is made by bicubic downsampling by
    ``downsample_factor`` followed by bicubic upsampling back to the original
    size, so input and target have identical dimensions. Images are loaded
    as single-channel ('L' mode).

    Args:
        high_res_dir: directory holding the high-resolution images. Regular
            files ending in .png, .jpg or .jpeg (case-insensitive) are used,
            in sorted filename order.
        transform: optional callable applied separately to the LR and the HR image.
        downsample_factor: integer >= 1 controlling how strongly the input is degraded.

    Raises:
        ValueError: if ``downsample_factor`` is less than 1.

    Note:
        ``transform`` runs *after* degradation. If it resizes (e.g.
        ``transforms.Resize``), the effective degradation at the final
        resolution depends on each source image's original size.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, high_res_dir, transform=None, downsample_factor=2):
        if downsample_factor < 1:
            raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")
        self.high_res_dir = high_res_dir
        self.transform = transform
        self.downsample_factor = downsample_factor
        # Sorted so the index -> file mapping is deterministic (os.listdir order is arbitrary).
        self.file_names = sorted(
            name for name in os.listdir(high_res_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(high_res_dir, name))
        )

    def __len__(self):
        """Number of usable image files found in ``high_res_dir``."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(lr_image, hr_image)`` for the ``idx``-th file, transformed if a transform is set."""
        img_name = self.file_names[idx]
        path = os.path.join(self.high_res_dir, img_name)
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(path) as img:
            hr_image = img.convert('L')

        # Degrade: bicubic downsample, then bicubic upsample back to the HR size.
        # max(1, ...) prevents a zero-sized intermediate for images smaller than the factor.
        lr_size = (max(1, hr_image.width // self.downsample_factor),
                   max(1, hr_image.height // self.downsample_factor))
        lr_image = hr_image.resize(lr_size, Image.BICUBIC)
        lr_image = lr_image.resize(hr_image.size, Image.BICUBIC)

        if self.transform is not None:
            hr_image = self.transform(hr_image)
            lr_image = self.transform(lr_image)

        return lr_image, hr_image

# Training configuration: tune these values rather than editing the calls below.
SEED = 42
HR_SIZE = (1024, 1280)      # (height, width) passed to transforms.Resize
DOWNSAMPLE_FACTOR = 2
BATCH_SIZE = 16             # NOTE(review): at 1024x1280, layer1 activations alone are ~5.4 GB fp32 per batch of 16; lower this if you hit OOM
LEARNING_RATE = 1e-3
num_epochs = 5              # lowercase name kept: the training-loop cell reads it

# Seed torch's RNG so weight initialisation and DataLoader shuffling are repeatable.
torch.manual_seed(SEED)

# Resize every image to one fixed size (needed to stack a batch), then convert to a [0, 1] tensor.
transform = transforms.Compose([transforms.Resize(HR_SIZE), transforms.ToTensor()])

# Example data loading: point this at your own directory of high-resolution images.
dataset = XrayDataset('high_res_images', transform=transform, downsample_factor=DOWNSAMPLE_FACTOR)
if len(dataset) == 0:
    # Fail loudly instead of "training" on nothing and saving untrained weights.
    raise ValueError("No .png/.jpg/.jpeg images found in 'high_res_images'")
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize model, loss, optimizer
model = SRCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)



In [None]:

# Training loop: one pass over the dataloader per epoch, reporting the epoch's mean training loss.
model.train()  # explicit, in case an evaluation cell left the model in eval mode
for epoch in range(num_epochs):
    running_loss = 0.0
    seen = 0
    for lr_imgs, hr_imgs in dataloader:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        optimizer.zero_grad()
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)
        loss.backward()
        optimizer.step()

        # Assumes criterion averages over the batch (nn.MSELoss default), so weight by
        # batch size for an exact epoch mean. detach() keeps it off the graph and avoids
        # a per-step host sync; the single conversion to float happens once per epoch.
        batch_n = lr_imgs.size(0)
        running_loss += loss.detach() * batch_n
        seen += batch_n

    if seen:
        print(f"epoch {epoch + 1}/{num_epochs}  train loss: {float(running_loss) / seen:.6f}")

# Save the trained model weights
torch.save(model.state_dict(), 'srcnn_xray.pth')