# Imports

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

In [None]:
class CrackRemovalNet(nn.Module):
    """Two-layer convolutional image-to-image network.

    Pipeline: 3x3 conv (3 -> 64 channels), ReLU, 3x3 conv (64 -> 3 channels).
    Both convolutions use ``padding=1`` with a 3x3 kernel, so the output has
    the same height and width as the input. The output is not bounded to
    any value range.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys of saved
        # checkpoints, and their registration order fixes parameter order.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an image batch ``x`` to a same-sized 3-channel batch."""
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)


class CrackDataset(Dataset):
    """Paired (damaged input, clean target) image dataset for training.

    Every ``.jpg`` file in ``root_dir`` is used as an input. Its target is the
    file in the same folder whose name is the input name with ``'_old'``
    removed. Both images are resized to ``target_size`` and receive the SAME
    random augmentation, so the pair stays pixel-aligned.

    NOTE(review): files whose name does not contain ``'_old'`` map to
    themselves and therefore form identity pairs; confirm this is intended.
    """

    def __init__(self, root_dir, target_size=(256, 256)):
        self.root_dir = root_dir
        # Sorted so that sample indices do not depend on os.listdir order.
        self.image_files = sorted(f for f in os.listdir(root_dir) if f.endswith('.jpg'))
        self.target_size = target_size
        # Built once instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.Resize(self.target_size),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` tensors for sample ``idx``."""
        file_name = self.image_files[idx]
        img_name = os.path.join(self.root_dir, file_name)
        target_name = os.path.join(self.root_dir, file_name.replace('_old', ''))

        # `with` closes the underlying file; convert() has already loaded the pixels.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        with Image.open(target_name) as tgt:
            target = tgt.convert('RGB')

        # The torchvision random transforms draw their parameters from the
        # global torch RNG. Replaying the same RNG state for the target gives
        # it exactly the rotation, flip and colour jitter applied to the input;
        # drawing fresh values would misalign the training pair.
        rng_state = torch.get_rng_state()
        image = self.transform(image)
        torch.set_rng_state(rng_state)
        target = self.transform(target)

        return image, target


def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: ``nn.Module`` to optimise. Each batch is moved to the device
            of the model's first parameter before the forward pass.
        train_loader: iterable yielding ``(inputs, targets)`` tensor batches.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimiser already bound to ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        A list with one entry per epoch: the batch-size-weighted mean of the
        batch losses (``nan`` for an epoch that yielded no batches).
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_samples = 0
        for inputs, targets in train_loader:
            if device is not None:
                # Keep data on the same device as the model (e.g. GPU).
                inputs = inputs.to(device)
                targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size

        # Report the whole-epoch average, not just the final batch's loss;
        # guard against an empty loader, which previously raised NameError.
        mean_loss = running_loss / num_samples if num_samples else float('nan')
        epoch_losses.append(mean_loss)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss}')

    return epoch_losses


# Training entry point: build the paired dataset, fit the network, and save
# its weights to Google Drive.
root_dir = '/content/drive/MyDrive/new_images'
checkpoint_path = '/content/drive/MyDrive/model_crapaturi_bun.pth'

dataset = CrackDataset(root_dir)
train_loader = DataLoader(dataset, shuffle=True, batch_size=4)

model = CrackRemovalNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

train_model(model, train_loader, criterion, optimizer, num_epochs=20)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
class CrackRemovalNet(nn.Module):
    """Inference-cell copy of the training network (conv 3->64, ReLU, conv 64->3).

    Must stay layer-for-layer identical to the definition used for training:
    the saved state_dict is loaded into this class, and a mismatch in layer
    names or shapes makes ``load_state_dict`` fail.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2; spatial size is preserved."""
        features = self.conv1(x)
        features = self.relu(features)
        return self.conv2(features)

class TestDataset(Dataset):
    """Unlabelled ``.jpg`` images from ``root_dir`` for inference.

    Each item is the image converted to RGB and turned into a float tensor
    by ``transforms.ToTensor()``; no resizing is applied.
    """

    # Prefix of the result files the inference loop saves into this same
    # folder. They are skipped so that re-running inference does not feed
    # earlier outputs back into the model (and overwrite results).
    OUTPUT_PREFIX = 'processed_image_'

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # Sorted so the enumeration index used for output names is stable.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.endswith('.jpg') and not f.startswith(self.OUTPUT_PREFIX)
        )
        # Built once instead of on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the image tensor for sample ``idx``."""
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # `with` closes the file; convert() has already loaded the pixels.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        return self.transform(image)

# Inference: load the trained weights, run the network over every test image,
# show input/output side by side, and save each result next to the inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = CrackRemovalNet()
model = model.to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you
# produced yourself (recent PyTorch versions also accept weights_only=True).
model.load_state_dict(torch.load('/content/drive/MyDrive/model_crapaturi_bun.pth', map_location=device))
model.eval()

test_root_dir = '/content/drive/MyDrive/test_images'
test_dataset = TestDataset(test_root_dir)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

to_pil = transforms.ToPILImage()

with torch.no_grad():
    for i, inputs in enumerate(test_loader):
        inputs = inputs.to(device)
        # The last conv layer is unbounded, but ToPILImage expects floats in
        # [0, 1] and does not clip them, so clamp before converting to 8-bit.
        outputs = model(inputs).clamp(0.0, 1.0)

        # batch_size=1: drop the batch dimension once and reuse the image for
        # both display and saving. The padded 3x3 convs keep H x W unchanged,
        # so no resize back to the original size is needed.
        output_image = to_pil(outputs.squeeze(0).cpu())

        fig, (ax_in, ax_out) = plt.subplots(1, 2)
        ax_in.imshow(inputs.squeeze(0).permute(1, 2, 0).cpu().numpy())
        ax_in.set_title('Original Image')
        ax_out.imshow(output_image)
        ax_out.set_title('Processed Image')
        plt.show()
        plt.close(fig)

        output_image_path = os.path.join(test_root_dir, f'processed_image_{i}.jpg')
        output_image.save(output_image_path)