In [171]:
import os
import torch
import torchvision
from torchvision import models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torch import nn
import random
import cv2

In [9]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device

'cuda'

In [80]:
current_directory = os.getcwd()
DATA_PATH = os.path.join(current_directory, r'archive\Bridge_Crack_Image\DBCC_Training_Data_Set')
DATA_PATH_TRAIN = os.path.join(current_directory, r'archive\Bridge_Crack_Image\DBCC_Training_Data_Set\train')
DATA_PATH_VAL = os.path.join(current_directory, r"archive\Bridge_Crack_Image\DBCC_Training_Data_Set\val")

In [159]:
class DefectDataset(Dataset):

    def __init__(self, img_dir, label_dir, transform) -> None:
        super().__init__()
        self.transform = transform
        self.img_paths = []
        self.img_labels = []
        with open(label_dir) as a:
            for line in a:
                name, label = line.replace('\n', '').split(' ')
                self.img_paths.append(os.path.join(img_dir, name))
                self.img_labels.append(int(label))


    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index) -> tuple:
        img_path = self.img_paths[index]
        image = cv2.imread(img_path)

        if self.transform:
            image = self.transform(image)

        label = self.img_labels[index]

        return (image, label)

In [160]:
transform = v2.Compose([
    v2.ToTensor(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [177]:
from torch.utils.data import random_split

batch_size = 64
dataset = DefectDataset(DATA_PATH_TRAIN, DATA_PATH + r'\train.txt', transform)
train_dataset, test_dataset = random_split(dataset, [0.8, 0.2])
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
val_dataset = DefectDataset(DATA_PATH_VAL, DATA_PATH + r'\val.txt', transform)
val_loader = DataLoader(val_dataset, batch_size=180)

In [172]:
model = models.resnet50(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512), 
    nn.ReLU(), 
    nn.Linear(512, 64), 
    nn.ReLU(), 
    nn.Linear(64, 1), 
    nn.Sigmoid()
    )
for param in model.fc.parameters():
    param.requires_grad = True
print(model)
resnet = model.to(device)

loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [180]:
def val(model, val_loader):
    model.eval()
    tp = 0
    total = 0

    for x,y in val_loader:
        y = y.to(device)
        pre = model(x.to(device))
        pred = torch.Tensor([1 if i>0.5 else 0 for i in pre])
        tp += int(torch.eq(torch.Tensor(pred),y.to('cpu')).sum())
        total +=len(pred)
    
    return tp/total

def train(dataloader, model, loss_fn, op_fn, epoch):
    for ep in range(epoch):
        for step, (x, y) in enumerate(dataloader):

            (x, y) = (x.to(device), y.to(device))
            y_pred = torch.flatten(model.forward(x))
            y_pred = y_pred.to(device)
            loss = loss_fn(y_pred, y)

            op_fn.zero_grad()
            loss.backward()
            op_fn.step()

            if step%100 == 0:
                loss, current = loss.item(), (ep+1)*(step+1)*batch_size
                print(f"loss = {loss}, samples = {current}")

        acc = val(model, val_loader)
        print(f'accuracy {acc}:.2f')

In [181]:
train(train_loader, resnet, loss_function, optimizer, 20)

  return F.conv2d(input, weight, bias, self.stride,


loss = 0.6704172492027283, samples = 64
loss = 0.11877257376909256, samples = 6464
loss = 0.1741468608379364, samples = 12864
loss = 0.17040714621543884, samples = 19264
loss = 0.06314453482627869, samples = 25664
loss = 0.060732390731573105, samples = 32064
loss = 0.13440397381782532, samples = 38464
accuracy 0.9634:.2f
loss = 0.17077016830444336, samples = 128
loss = 0.06731943041086197, samples = 12928
loss = 0.040667831897735596, samples = 25728
loss = 0.09303343296051025, samples = 38528
loss = 0.043540388345718384, samples = 51328
loss = 0.054530054330825806, samples = 64128
loss = 0.050152674317359924, samples = 76928
accuracy 0.9698:.2f
loss = 0.08747145533561707, samples = 192
loss = 0.061659570783376694, samples = 19392
loss = 0.023392487317323685, samples = 38592
loss = 0.08006488531827927, samples = 57792
loss = 0.048498958349227905, samples = 76992
loss = 0.036655500531196594, samples = 96192
loss = 0.03566819801926613, samples = 115392
accuracy 0.9734:.2f
loss = 0.0480879

KeyboardInterrupt: 

In [188]:
predicts = []
labels = []

for x,y in test_loader:
    y = y.to(device)
    pre = resnet(x.to(device))
    pred = [1 if i>0.5 else 0 for i in pre]
    predicts.append(pred)
    labels.append(y)

  return F.conv2d(input, weight, bias, self.stride,


In [216]:
a = [i.tolist() for i in labels]

In [229]:
true = 0
count = 0
for x,y in zip(a,predicts):
    for xi,yi in zip(x,y):
        if xi==yi:
            true += 1
        count += 1
print(true/count)

0.9781
