In [1]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import os
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large as deeplab

In [2]:
# Root directory holding the input PNG images. myDS derives each label path
# from an image path by swapping "raw_assembled" for "raw_label".
fd = "/home/jupyter/ai_font/data/exp0717/train_seg/raw_assembled"

In [3]:
class myDS(Dataset):
    """Segmentation dataset pairing each PNG image with a ``.npy`` label array.

    Images are discovered by walking ``fd`` recursively. For an image at
    ``<root>/<name>.png`` the label is looked up at the same location with
    ``raw_assembled`` replaced by ``raw_label`` in the directory path and the
    extension changed to ``.npy``. Label files are not checked for existence
    here; a missing one surfaces as an error from ``np.load`` in
    ``__getitem__``.

    Args:
        fd: Root directory to walk for ``.png`` images.
        train: If truthy, keep every sample whose position in the sorted pair
            list is not a multiple of 100; otherwise keep only those multiples
            (a fixed, deterministic ~1% hold-out).

    Attributes:
        path_list: Sorted list of ``(image_path, label_path)`` tuples for the
            selected split.
        transforms: Image pipeline (resize to 520, to tensor, normalize).
        label_transforms: Resize applied to the label tensor.
    """

    def __init__(self, fd, train=True):
        pairs = []
        for root, _dirs, files in os.walk(fd):
            # The label directory mirrors the image directory.
            label_root = root.replace("raw_assembled", "raw_label")
            for file in files:
                # Keep PNGs only and skip anything named like a notebook
                # checkpoint copy. Logical `and` (not bitwise `&`) so the
                # second test is short-circuited and precedence is obvious.
                if file.endswith('.png') and "checkpoint" not in file:
                    image_path = os.path.join(root, file)
                    label_path = os.path.join(label_root, file.replace(".png", ".npy"))
                    pairs.append((image_path, label_path))

        self.transforms = transforms.Compose([
                transforms.Resize(520),
                transforms.ToTensor(),  # Rescales to [0.0, 1.0]
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        # NOTE(review): Resize defaults to bilinear interpolation, which blends
        # label values along class boundaries. Presumably the masks are float
        # per-class maps and soft labels are acceptable -- confirm, or switch
        # to nearest-neighbour interpolation for hard labels.
        self.label_transforms = transforms.Resize(520)

        # os.walk order is filesystem dependent; sorting once makes the
        # train/hold-out split reproducible across runs and machines.
        pairs.sort()
        keep_holdout = not train
        self.path_list = [p for i, p in enumerate(pairs) if (i % 100 == 0) == keep_holdout]

    def __len__(self):
        """Return the number of (image, label) pairs in this split."""
        return len(self.path_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        image_path, label_path = self.path_list[idx]
        image = Image.open(image_path).convert("RGB")
        mask = np.load(label_path)

        return self.transforms(image), self.label_transforms(torch.from_numpy(mask))

In [4]:
def calculate_accuracy(preds, targets):
    """Return the percentage of positions where prediction and target agree.

    Args:
        preds: Tensor of per-class scores with the class axis at ``dim=1``.
        targets: Either a tensor with the same number of dimensions as
            ``preds`` (per-class maps, reduced with argmax over ``dim=1``),
            or a tensor with exactly one dimension fewer, which is treated as
            already holding class indices.

    Returns:
        Accuracy as a Python float in the range 0-100.
    """
    # Convert predictions to class labels by taking the argmax along the class dimension
    pred_labels = torch.argmax(preds, dim=1)

    if targets.dim() == preds.dim() - 1:
        # Targets are already class indices; an argmax over dim=1 here would
        # collapse a spatial axis and compare mismatched shapes.
        target_labels = targets
    else:
        target_labels = torch.argmax(targets, dim=1)

    # Check if the predicted labels are equal to the target labels
    correct = (pred_labels == target_labels).float()

    # Fraction of matching positions
    accuracy = correct.sum() / correct.numel()

    return accuracy.item() * 100  # Convert to percentage

In [5]:
# Build both splits from the single data root defined above (`fd`) instead of
# repeating the path literal, so the location only has to be changed in one place.
# Training split: shuffled every epoch.
ds = myDS(fd, train=True)
dl = DataLoader(ds, batch_size=32, shuffle=True)
# Hold-out split: fixed order so per-epoch evaluation is comparable.
testds = myDS(fd, train=False)
testdl = DataLoader(testds, batch_size=32, shuffle=False)

In [6]:
# Instantiate the DeepLabV3 / MobileNetV3-Large segmentation network with four
# output classes and move it to the GPU in the same expression.
model = deeplab(weights_backbone="DEFAULT", num_classes=4).cuda()

In [7]:
# Loss applied to the model's 'out' scores against the label tensor in the training loop.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [8]:
# Total number of epochs; the training loop iterates over range(epoch).
epoch = 100

In [None]:
# Train for `epoch` epochs. After each epoch, evaluate on the hold-out split,
# print mean per-batch accuracy (train / test) and save a checkpoint.
pbar = tqdm(range(epoch))
for i in pbar:
    train_acc = []
    test_acc = []

    # ---- training pass ----
    model.train()
    # Count batches from 1 so the progress postfix ends at N/N rather than (N-1)/N.
    for j, data in enumerate(dl, start=1):
        optimizer.zero_grad()
        x, y = data
        x = x.cuda()
        y = y.cuda()
        pred = model(x)
        loss = criterion(pred['out'], y)
        loss.backward()
        optimizer.step()
        train_acc.append(calculate_accuracy(pred['out'], y))
        pbar.set_postfix(train=f"{j}/{len(dl)}")

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and cuts GPU memory use.
    with torch.no_grad():
        for j, data in enumerate(testdl, start=1):
            x, y = data
            x = x.cuda()
            y = y.cuda()
            pred = model(x)
            test_acc.append(calculate_accuracy(pred['out'], y))
            pbar.set_postfix(test=f"{j}/{len(testdl)}")

    print(f"{i}: {np.mean(train_acc)} / {np.mean(test_acc)}")
    # One checkpoint per epoch, written to the current working directory.
    torch.save(model.state_dict(), f"model_{i}.pth")

  1%|          | 1/100 [17:18<28:33:22, 1038.41s/it, test=10/11]

0: 97.45128085344425 / 98.91342520713806


  2%|▏         | 2/100 [34:20<28:00:17, 1028.75s/it, test=10/11]     

1: 99.07898655550316 / 99.15885979479009


  3%|▎         | 3/100 [51:41<27:52:22, 1034.46s/it, test=10/11]     

2: 99.25512999241101 / 99.24615784124894


  4%|▍         | 4/100 [1:09:04<27:40:22, 1037.74s/it, test=10/11]     

3: 99.3292336983954 / 99.37049150466919


  5%|▌         | 5/100 [1:26:28<27:26:29, 1039.89s/it, test=10/11]     

4: 99.36787898681521 / 99.39074353738265


  6%|▌         | 6/100 [1:43:47<27:09:07, 1039.87s/it, test=10/11]     

5: 99.38742787802727 / 99.3605527010831


  7%|▋         | 7/100 [2:01:09<26:52:29, 1040.32s/it, test=10/11]     

6: 99.41379938019843 / 99.41582408818331


  8%|▊         | 8/100 [2:18:18<26:29:57, 1036.93s/it, test=10/11]     

7: 99.4274049226105 / 99.43248521197926


  9%|▉         | 9/100 [2:35:42<26:15:43, 1038.94s/it, test=10/11]     

8: 99.43993269515786 / 99.44910623810507


 10%|█         | 10/100 [2:53:07<26:01:13, 1040.82s/it, test=10/11]    

9: 99.44969801827853 / 99.40828084945679


 11%|█         | 11/100 [3:10:28<25:44:12, 1041.04s/it, test=10/11]     

10: 99.45268388596568 / 99.46125420657071


 12%|█▏        | 12/100 [3:27:50<25:26:56, 1041.09s/it, test=10/11]     

11: 99.46574737316139 / 99.46190443905917


 13%|█▎        | 13/100 [3:45:12<25:10:11, 1041.51s/it, test=10/11]     

12: 99.46962547610735 / 99.46669990366155


 14%|█▍        | 14/100 [4:02:21<24:47:24, 1037.73s/it, test=10/11]     

13: 99.47404819583717 / 99.48311285539107


 15%|█▌        | 15/100 [4:19:33<24:27:45, 1036.06s/it, test=10/11]     

14: 99.4738831482621 / 99.48139624162154


 16%|█▌        | 16/100 [4:36:54<24:12:40, 1037.62s/it, test=10/11]     

15: 99.48386393501225 / 99.48683814568953


 17%|█▋        | 17/100 [4:54:18<23:57:48, 1039.37s/it, test=10/11]     

16: 99.48685778487412 / 99.37734983184122


 18%|█▊        | 18/100 [5:11:40<23:41:28, 1040.11s/it, test=10/11]     

17: 99.48904569951091 / 99.46140105074102


 19%|█▉        | 19/100 [5:28:40<23:16:04, 1034.13s/it, test=10/11]     

18: 99.49352541273932 / 99.49077692898837


 20%|██        | 20/100 [5:45:55<22:59:04, 1034.31s/it, test=10/11]     

19: 99.49659765758268 / 99.45984970439564


 21%|██        | 21/100 [6:03:08<22:41:34, 1034.11s/it, test=10/11]     

20: 99.49877135414293 / 99.46082505312833


 22%|██▏       | 22/100 [6:20:37<22:29:52, 1038.36s/it, test=10/11]     

21: 99.48884546756744 / 99.49671084230596


 23%|██▎       | 23/100 [6:38:00<22:14:30, 1039.87s/it, test=10/11]     

22: 99.50160736737983 / 99.50143532319503


 24%|██▍       | 24/100 [6:55:22<21:57:49, 1040.39s/it, test=10/11]     

23: 99.50402942029915 / 99.47370995174755


 25%|██▌       | 25/100 [7:12:21<21:32:44, 1034.19s/it, test=10/11]     

24: 99.5058732916379 / 99.51123053377324


 26%|██▌       | 26/100 [7:29:45<21:19:05, 1037.10s/it, test=10/11]     

25: 99.50754917373058 / 99.49896173043685


 27%|██▋       | 27/100 [7:46:56<20:59:39, 1035.33s/it, test=10/11]     

26: 99.50337945974248 / 99.47669343514876


 28%|██▊       | 28/100 [8:04:18<20:44:43, 1037.27s/it, test=10/11]     

27: 99.50600363410555 / 99.51139146631414


 29%|██▉       | 29/100 [8:21:18<20:21:23, 1032.16s/it, test=10/11]     

28: 99.51024078617695 / 99.50415275313638


 30%|███       | 30/100 [8:38:40<20:07:29, 1035.00s/it, test=10/11]     

29: 99.51179975061892 / 99.50035160238093


 31%|███       | 31/100 [8:55:57<19:50:59, 1035.65s/it, test=10/11]     

30: 99.51336887319074 / 99.49642799117349


 32%|███▏      | 32/100 [9:13:18<19:35:27, 1037.17s/it, test=10/11]     

31: 99.51448404194026 / 99.49423399838534


 33%|███▎      | 33/100 [9:30:40<19:19:45, 1038.59s/it, test=10/11]     

32: 99.51529737541283 / 99.45346008647572


 34%|███▍      | 34/100 [9:47:48<18:59:10, 1035.62s/it, test=10/11]     

33: 99.51308466590487 / 99.50170300223611


 35%|███▌      | 35/100 [10:04:48<18:36:50, 1030.93s/it, test=10/11]     

34: 99.51746587722447 / 99.49666370045055


 36%|███▌      | 36/100 [10:21:49<18:16:14, 1027.73s/it, test=10/11]     

35: 99.51766900126022 / 99.48758645491166


 37%|███▋      | 37/100 [10:38:53<17:58:04, 1026.73s/it, test=10/11]     

36: 99.5193828551474 / 99.4960297237743


 38%|███▊      | 38/100 [10:56:16<17:45:59, 1031.61s/it, test=10/11]     

37: 99.52006737758404 / 99.50844103639776


 41%|████      | 41/100 [11:48:23<17:01:27, 1038.78s/it, test=10/11]     

40: 99.51406107482099 / 99.46367740631104


 42%|████▏     | 42/100 [12:05:48<16:45:58, 1040.67s/it, test=10/11]     

41: 99.52018746827314 / 99.51043562455611


 43%|████▎     | 43/100 [12:23:10<16:28:49, 1040.87s/it, test=10/11]     

42: 99.52172613849041 / 99.51757409355857


 44%|████▍     | 44/100 [12:40:31<16:11:35, 1040.99s/it, test=10/11]     

43: 99.52133110071948 / 99.51706962151961


 45%|████▌     | 45/100 [12:57:51<15:53:57, 1040.68s/it, test=10/11]     

44: 99.52507420079765 / 99.49387907981873


 46%|████▌     | 46/100 [13:15:10<15:36:08, 1040.16s/it, test=10/11]     

45: 99.52546611510893 / 99.46998357772827


 47%|████▋     | 47/100 [13:32:19<15:15:46, 1036.72s/it, test=10/11]     

46: 99.52580290889564 / 99.43283904682507


 48%|████▊     | 48/100 [13:49:29<14:56:54, 1034.89s/it, test=10/11]     

47: 99.52707739061437 / 99.43259520964189


 49%|████▉     | 49/100 [14:06:50<14:41:14, 1036.76s/it, test=10/11]     

48: 99.52690529735163 / 99.45306452837858


 50%|█████     | 50/100 [14:24:10<14:24:43, 1037.67s/it, test=10/11]     

49: 99.52830711376204 / 99.51196258718318


 51%|█████     | 51/100 [14:41:26<14:07:06, 1037.28s/it, test=10/11]     

50: 99.5286413680163 / 99.49113130569458


 52%|█████▏    | 52/100 [14:58:45<13:50:07, 1037.65s/it, test=10/11]     

51: 99.52913857652167 / 99.50723919001493


 53%|█████▎    | 53/100 [15:15:58<13:31:46, 1036.31s/it, test=10/11]     

52: 99.52947923744892 / 99.5022638277574


 54%|█████▍    | 54/100 [15:33:16<13:14:53, 1036.82s/it, test=10/11]     

53: 99.52281024892316 / 99.50450442054056


 55%|█████▌    | 55/100 [15:50:38<12:58:40, 1038.24s/it, test=10/11]     

54: 99.52946056830464 / 99.52128963036971


 56%|█████▌    | 56/100 [16:07:35<12:36:42, 1031.88s/it, test=10/11]     

55: 99.531194292243 / 99.42349032922225


 57%|█████▋    | 57/100 [16:24:32<12:16:20, 1027.45s/it, test=10/11]     

56: 99.52703748524961 / 99.50162389061667


 58%|█████▊    | 58/100 [16:41:50<12:01:24, 1030.59s/it, test=10/11]     

57: 99.530837347459 / 99.48106679049405


 59%|█████▉    | 59/100 [16:59:08<11:45:54, 1033.03s/it, test=10/11]     

58: 99.53208453333532 / 99.5209590955214


 60%|██████    | 60/100 [17:16:27<11:29:51, 1034.78s/it, test=10/11]     

59: 99.53245953579267 / 99.52196045355363


 61%|██████    | 61/100 [17:34:03<11:16:40, 1041.03s/it, test=10/11]     

60: 99.53274478423177 / 99.51523488218135


 62%|██████▏   | 62/100 [17:51:20<10:58:37, 1039.94s/it, test=10/11]     

61: 99.53281348932693 / 99.48879046873613


 63%|██████▎   | 63/100 [18:08:17<10:37:02, 1033.04s/it, test=10/11]     

62: 99.53043961591068 / 99.50822537595576


 74%|███████▍  | 74/100 [21:17:47<7:27:36, 1032.96s/it, test=10/11]     

73: 99.53201231365944 / 99.40469156612049


 75%|███████▌  | 75/100 [21:35:05<7:11:07, 1034.72s/it, test=10/11]     

74: 99.53760905759391 / 99.44701248949224


 77%|███████▋  | 77/100 [22:09:47<6:37:49, 1037.79s/it, test=10/11]     

76: 99.53744548636311 / 99.14368391036987


 77%|███████▋  | 77/100 [22:20:10<6:37:49, 1037.79s/it, train=651/1082]