In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np

class ImagePairDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.image_pairs = os.listdir(data_dir)[:64]
        
    def __len__(self):
        return len(self.image_pairs)
    
    def __getitem__(self, index):
        pair_dir = os.path.join(self.data_dir, self.image_pairs[index])
        angle_file = os.path.join(pair_dir, 'angle.txt')
        with open(angle_file, 'r') as f:
            angle = float(f.read().strip())
        
        bg_path = os.path.join(pair_dir, 'bg.jpg')
        bg_image = Image.open(bg_path).convert("L")
        cut_path = os.path.join(pair_dir, 'cut.jpg')
        cut_image = Image.open(cut_path).convert("L")
        
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        
        bg_tensor = transform(bg_image)
        cut_tensor = transform(cut_image)
        
        return bg_tensor, cut_tensor, torch.tensor(angle, dtype=torch.float)

class ImagePairModel(nn.Module):
    def __init__(self):
        super(ImagePairModel, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7396, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )
        
        self.flatten = nn.Flatten()
        
    def forward(self, bg, cut):
        x1 = self.cnn(bg)
        x2 = self.cnn(cut)
        x = torch.add(x1, x2)
        x = x.view(x.size(0), -1)
        x = self.flatten(x)
        x = self.fc(x)
        return x
    
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    train_loss = 0.0
    for bg, cut, angle in train_loader:
        bg, cut, angle = bg.to(device), cut.to(device), angle.to(device)
        optimizer.zero_grad()
        outputs = model(bg, cut)
        loss = criterion(outputs.squeeze(), angle)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)
    print('Train Loss: {:.6f}'.format(train_loss))
    
def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for bg, cut, angle in test_loader:
            bg, cut, angle = bg.to(device), cut.to(device), angle.to(device)
            outputs = model(bg, cut)
            loss = criterion(outputs.squeeze(), angle)
            test_loss += loss.item()
    test_loss /= len(test_loader)
    print('Test Loss: {:.6f}'.format(test_loss))
    
def main():
    # Hyperparameters
    batch_size = 32
    learning_rate = 0.001
    num_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Datasets and data loaders
    train_dataset = ImagePairDataset('/tmp/train')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = ImagePairDataset('/tmp/test')
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Model, loss function, and optimizer
    model = ImagePairModel().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Training loop
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        train(model, train_loader, criterion, optimizer, device)
        test(model, test_loader, criterion, device)
    
    # Save the model
    torch.save(model.state_dict(), 'model.pth')
    
# print(ImagePairModel())
main()

# train_dataset = ImagePairDataset('/tmp/train')
# bg, cut, angle = train_dataset[0]
# m = ImagePairModel().forward(bg, cut)
# print(m.shape)

# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# for bg, cut, angle in train_loader:
#     m = ImagePairModel().forward(bg, cut)
#     print("m.shape", m.shape)
#     break

Epoch 1/10


: 

: 