# Training

### Import Libraries

In [1]:
import os
import easydict
from PIL import Image as im
from PIL import ImageOps
import cv2
from tqdm import tqdm
import json
import sys
import shutil
from datetime import datetime
import pandas as pd
import random

op = os.path.join

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler

In [3]:
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, roc_curve
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from tensorboardX import SummaryWriter

In [4]:
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, roc_curve
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score
from tensorboardX import SummaryWriter

In [5]:
from pycocotools.coco import COCO

### Set Directory Path

In [6]:
DATA_DIR = "/home/dmsai2/mmdetection/data/"

TRAIN_DATASET_DIR = op(DATA_DIR, "classification")
TRAIN_IMAGE_DIR = op(TRAIN_DATASET_DIR, "train")
TRAIN_JSON_DIR = op(TRAIN_DATASET_DIR, "annotations")

TEST_DATASET_DIR = op(DATA_DIR, "classification")
TEST_IMAGE_DIR = op(TEST_DATASET_DIR, "test")
TEST_JSON_DIR = op(TEST_DATASET_DIR, "annotations")

current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y-%m-%d_%H_%M')
print(formatted_datetime)

os.makedirs(op("/home/dmsai2/mmdetection/work_dir/classification/logs", formatted_datetime), exist_ok=True)
os.makedirs(op("/home/dmsai2/mmdetection/work_dir/classification/weights", formatted_datetime), exist_ok=True)

2024-06-19_05_43


### Read File List

In [7]:
train_file_list = list(map(lambda x : x.split(".")[0], os.listdir(TRAIN_IMAGE_DIR)))
train_image_list = os.listdir(TRAIN_IMAGE_DIR)
train_json_list = os.listdir(TRAIN_JSON_DIR)
print("number of train file:", len(train_file_list))

test_file_list = list(map(lambda x : x.split(".")[0], os.listdir(TEST_IMAGE_DIR)))
test_image_list = os.listdir(TEST_IMAGE_DIR)
test_json_list = os.listdir(TEST_JSON_DIR)
print("number of train file:", len(test_file_list))

number of train file: 1991
number of train file: 250


### Undersampling

In [8]:
# train_df = pd.read_csv("/hoeunlee228/Dataset/train_df.csv")
# # print(train_df.head())
# train_df['is_decayed'].value_counts()

# not_decayed_rows = train_df[train_df['is_decayed'] == False]
# decayed_rows = train_df[train_df['is_decayed'] == True]
# print("total not decayed rows:", len(not_decayed_rows), "total decayed rows", len(decayed_rows))

# num_samples_not_decayed = 28136 * 3
# num_samples_decayed = 28136

# print("args num sampled not decayed:", num_samples_not_decayed, "args num sampled decayed:", num_samples_decayed)

# random_samples_not_decayed = random.sample(range(len(not_decayed_rows)), num_samples_not_decayed)
# not_decayed_sampled_df = not_decayed_rows.iloc[random_samples_not_decayed]
# print("num of not decayed sampled list", len(not_decayed_sampled_df))

# random_samples_decayed = random.sample(range(len(decayed_rows)), num_samples_decayed)
# decayed_sampled_df = decayed_rows.iloc[random_samples_decayed]
# print("num of decayed sampled list", len(decayed_sampled_df))

# train_file_list_not_decayed_sampled = list(map(lambda x : f"{x[0]}_{x[1]}", not_decayed_sampled_df[['file', 'teeth_idx']].values.tolist()))
# train_file_list_decayed_sampled = list(map(lambda x : f"{x[0]}_{x[1]}", decayed_sampled_df[['file', 'teeth_idx']].values.tolist()))
# print("num of sampled file list (not decayed, decayed):", len(train_file_list_not_decayed_sampled), len(train_file_list_decayed_sampled))

# sampled_train_list = train_file_list_not_decayed_sampled + train_file_list_decayed_sampled
# print("total num of sampled train list:", len(sampled_train_list))

### Custom Dataset

In [9]:
# class ToothDataset(Dataset):
#     def __init__(self, data_dir, ann_dir, file_list, transform=None, aug_transform=None, valid=False):
#         self.data_dir = data_dir
#         self.ann_dir = ann_dir
#         self.file_list = file_list
#         self.transform = transform
#         self.aug_transform = aug_transform
#         self.valid = valid

#     def __len__(self):
#         return len(self.file_list)
    
#     def _load_image(self, image_path):
#         assert os.path.exists(op(self.data_dir, image_path))
#         # return cv2.cvtColor(cv2.imread(op(self.data_dir, "image", image_path)), cv2.COLOR_BGR2RGB)
#         if self.valid:
#             return (im.open(op(self.data_dir, image_path)).convert("RGB"), image_path)
#         else:
#             return im.open(op(self.data_dir, image_path)).convert("RGB")
    
#     def __getitem__(self, index):
#         image_path = self.file_list[index] + ".png"
#         json_path = self.file_list[index] + ".json"

#         if self.valid:
#             image, label = self._load_image(image_path)
#         else:
#             image = self._load_image(image_path)

#         with open(op(self.ann_dir, json_path), 'r') as json_file:
#             data = json.load(json_file)

#         decayed = data["tooth"][0]["decayed"]
#         target = 1 if decayed else 0

#         if self.transform:
#             image = self.transform(image)

#         if self.aug_transform:
#             image = self.aug_transform(image)

#         if self.valid:
#             return (image, label), target
#         else:
#             return image, target

In [10]:
class ToothCOCODataset(Dataset):
    def __init__(self, data_dir, ann_file, transform=None, aug_transform=None, valid=False):
        self.data_dir = data_dir
        self.coco = COCO(ann_file)
        self.image_ids = self.coco.getImgIds()
        self.transform = transform
        self.aug_transform = aug_transform
        self.valid = valid

    def __len__(self):
        return len(self.image_ids)
    
    def _load_image(self, image_id):
        image_info = self.coco.loadImgs(image_id)[0]
        image_path = op(self.data_dir, image_info['file_name'])
        assert os.path.exists(image_path), f"Image path {image_path} does not exist."
        
        if self.valid:
            return (im.open(image_path).convert("RGB"), image_info['file_name'])
        else:
            return im.open(image_path).convert("RGB")
    
    def _load_target(self, image_id):
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        anns = self.coco.loadAnns(ann_ids)
        # Assuming 'decayed' is the attribute that indicates decay
        # decayed = any(ann.get('category_id', False) for ann in anns)
        decayed = any(ann.get('category_id', False) for ann in anns)
        target = 1 if decayed else 0
        return target

    def __getitem__(self, index):
        image_id = self.image_ids[index]

        if self.valid:
            image, label = self._load_image(image_id)
        else:
            image = self._load_image(image_id)

        target = self._load_target(image_id)

        if self.transform:
            image = self.transform(image)

        if self.aug_transform:
            image = self.aug_transform(image)

        if self.valid:
            return (image, label), target
        else:
            return image, target

In [11]:
class SeparableConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super(SeparableConv2d,self).__init__()

        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
        self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)

    def forward(self,x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x+=skip
        return x


class Xception(nn.Module):
    def __init__(self, num_classes=1000):
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3,32,3,2,0,bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1)) 
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


## 기존 Xception에 Dropout만 추가
class xception(nn.Module):
    def __init__(self, num_out_classes=2, dropout=0.5):
        super(xception, self).__init__()

        self.model = Xception(num_classes=num_out_classes)
        self.model.last_linear = self.model.fc
        del self.model.fc

        num_ftrs = self.model.last_linear.in_features
        if not dropout:
            self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
        else:            
            self.model.last_linear = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(num_ftrs, num_out_classes)
            )

    def forward(self, x):
        x = self.model(x)
        return x

In [12]:
cudnn.benchmark = True

args = easydict.EasyDict({
    "gpu": 0,
    "num_workers": 4,
    "root": "/home/dmsai2/mmdetection/work_dir/classification/",
    "learning_rate": 1e-4,
    "num_epochs": 50,
    "batch_size": 16,

    "save_fn": f"/home/dmsai2/mmdetection/work_dir/classification/weights/{formatted_datetime}/xception",
    "load_fn": None,
    "scheduler": None,

    "scheduler_step": 1,
    "scheduler_gamma": 0.001
})

In [13]:
# 1. Zero padding to make the image square
def make_square(img):
    width, height = img.size
    max_side = max(width, height)
    left = (max_side - width) // 2
    top = (max_side - height) // 2
    right = (max_side - width) - left
    bottom = (max_side - height) - top
    padding = (left, top, right, bottom)
    return ImageOps.expand(img, padding)

In [14]:
mean = (0.57933619, 0.42688786, 0.33401168)
std = (0.35580848, 0.27125023, 0.22251333)

"""train_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

valid_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])"""

train_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),

    # 3. Augmentation
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),

    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),
    
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

valid_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),

    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

In [15]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("pth file saved at " + filename)
    if is_best:
        shutil.copyfile(filename, f"/home/dmsai2/mmdetection/work_dir/classification/weights/{formatted_datetime}/model_best.pth.tar")

def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [16]:
# train / validate
def train(train_loader, model, criterion, optimizer, epoch, writer, step):   
    n = 0
    running_loss = 0.0
    running_corrects = 0

    all_targets = []
    all_preds = []

    error_count = {
        "precision": 0,
        "recall": 0,
        "auc": 0
    }

    model.train()

    with tqdm(train_loader, total=len(train_loader), desc="Train", file=sys.stdout) as iterator:
        for images, target in iterator:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            outputs = model(images)
            _, pred = torch.max(outputs.data, 1)

            loss = criterion(outputs, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n += images.size(0)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(pred == target.data)

            epoch_loss = running_loss / float(n)
            epoch_acc = running_corrects / float(n)

            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())

            # Calculate F1-Score
            f1 = f1_score(all_targets, all_preds, average='weighted')

            # Calculate Precision and Recall
            try:
                precision = precision_score(all_targets, all_preds, average='weighted', zero_division=1)
            except Exception as e:
                precision = -1.0
                error_count["precision"] += 1

            try:
                recall = recall_score(all_targets, all_preds, average='weighted')
            except Exception as e:
                recall = -1.0
                error_count["recall"] += 1

            # Calculate AUC
            # fpr, tpr, _ = roc_curve(target.cpu(), outputs.data[:, 1].cpu())
            # auc_value = roc_auc_score(target.cpu(), outputs.data[:, 1].cpu())
            try:
                # fpr, tpr, _ = roc_curve(all_targets, all_preds)
                auc_value = roc_auc_score(all_targets, all_preds)
            except Exception as e:
                auc_value = -1.0
                error_count["auc"] += 1

            log = 'loss - {:.4f}, acc - {:.4f}, F1 - {:.4f}, Precision - {:.4f}, Recall - {:.4f}, AUC - {:.4f}'.format(epoch_loss, epoch_acc, f1, precision, recall, auc_value)
            iterator.set_postfix_str(log)

            if step % 10 == 0:
                writer.add_scalar("Train/Step/Loss", epoch_loss, step)
                writer.add_scalar("Train/Step/Accuracy", epoch_acc, step)
                writer.add_scalar("Train/Step/F1-Score", f1, step)
                writer.add_scalar("Train/Step/Precision", precision, step)
                writer.add_scalar("Train/Step/Recall", recall, step)
                writer.add_scalar("Train/Step/AUC", auc_value, step)

            step += 1

    writer.add_scalar("Train/Epoch/Loss", epoch_loss, epoch)
    writer.add_scalar("Train/Epoch/Accuracy", epoch_acc, epoch)
    writer.add_scalar("Train/Epoch/F1-Score", f1, epoch)
    writer.add_scalar("Train/Epoch/Precision", precision, epoch)
    writer.add_scalar("Train/Epoch/Recall", recall, epoch)
    writer.add_scalar("Train/Epoch/AUC", auc_value, epoch)

    print(error_count)

    # scheduler.step()

    return step


def validate(test_loader, model, criterion, epoch, writer):

    all_lables = []

    n = 0
    running_loss = 0.0
    running_corrects = 0

    all_targets = []
    all_preds = []

    error_count = {
        "precision": 0,
        "recall": 0,
        "auc": 0
    }

    model.eval()

    with tqdm(test_loader, total=len(test_loader), desc="Valid", file=sys.stdout) as iterator:
        for images, target in iterator:

            images, labels = images[0], images[1]

            all_lables.extend(labels)

            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            with torch.no_grad():
                output = model(images)

            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)

            n += images.size(0)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(pred == target.data)

            epoch_loss = running_loss / float(n)
            epoch_acc = running_corrects / float(n)

            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())

            # Calculate F1-Score
            f1 = f1_score(all_targets, all_preds, average='weighted')

            try:
                precision = precision_score(all_targets, all_preds, average='weighted', zero_division=1)
            except Exception as e:
                precision = -1.0
                error_count["precision"] += 1
            
            try:
                recall = recall_score(all_targets, all_preds, average='weighted')
            except Exception as e:
                recall = -1.0
                error_count["recall"] += 1

            # Calculate AUC
            # fpr, tpr, _ = roc_curve(target.cpu(), output.data[:, 1].cpu())
            # auc_value = roc_auc_score(target.cpu(), output.data[:, 1].cpu())
            try:
                # fpr, tpr, _ = roc_curve(all_targets, all_preds)
                auc_value = roc_auc_score(all_targets, all_preds)
            except Exception as e:
                auc_value = -1.0
                error_count["auc"] += 1

            log = 'loss - {:.4f}, acc - {:.4f}, F1 - {:.4f}, Precision - {:.4f}, Recall - {:.4f}, AUC - {:.4f}'.format(epoch_loss, epoch_acc, f1, precision, recall, auc_value)
            iterator.set_postfix_str(log)

    writer.add_scalar("Validation/Epoch/Loss", epoch_loss, epoch)
    writer.add_scalar("Validation/Epoch/Accuracy", epoch_acc, epoch)
    writer.add_scalar("Validation/Epoch/F1-Score", f1, epoch)
    writer.add_scalar("Validation/Epoch/Precision", precision, epoch)
    writer.add_scalar("Validation/Epoch/Recall", recall, epoch)
    writer.add_scalar("Validation/Epoch/AUC", auc_value, epoch)

    print(error_count)

    return epoch_acc, all_lables, all_preds, f1

In [17]:
# TensorboardX를 사용하여 summary writer 생성
train_writer = SummaryWriter(f"/home/dmsai2/mmdetection/work_dir/classification/logs/{formatted_datetime}/train")  # 'logs/train'는 로그가 저장될 디렉토리입니다.
test_writer = SummaryWriter(f"/home/dmsai2/mmdetection/work_dir/classification/logs/{formatted_datetime}/validation")  # 'logs/validation'은 로그가 저장될 디렉토리입니다.

In [18]:
model = xception(num_out_classes=2, dropout=0.5)
print("=> creating model '{}'".format('xception'))
model = model.cuda(args.gpu)

=> creating model 'xception'


In [19]:
if args.load_fn is not None:
    assert os.path.isfile(args.load_fn), 'wrong path'

    model.load_state_dict(torch.load(args.load_fn)['state_dict'])
    print("=> model weight '{}' is loaded".format(args.load_fn))

In [20]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08)

In [21]:
if args.scheduler == "steplr":
    print("StepLR selected to Scheduler")
    scheduler = lr_scheduler.StepLR(optimizer, 
                                    step_size=args.scheduler_step,
                                    gamma=args.scheduler_gamma)
elif args.scheduler == 'exp':
    print("ExponentialLR selected to Scheduler")
    scheduler = lr_scheduler.ExponentialLR(optimizer, 
                                    gamma=args.scheduler_gamma)

In [22]:
train_file_list = list(map(lambda x : x.split(".")[0], os.listdir("/home/dmsai2/mmdetection/data/classification/train")))
test_file_list = list(map(lambda x : x.split(".")[0], os.listdir("/home/dmsai2/mmdetection/data/classification/test")))
print(len(train_file_list), len(test_file_list))

# train_dataset = ToothCOCODataset(data_dir="/home/dmsai2/mmdetection/data/",
#                              file_list=train_file_list,
#                              transform=train_aug_transform)

# valid_dataset = ToothCOCODataset(data_dir="/home/dmsai2/mmdetection/data/",
#                              file_list=test_file_list,
#                              transform=test_aug_transform, valid=True)

train_dataset = ToothCOCODataset(data_dir="/home/dmsai2/mmdetection/data/classification/train", 
                                 ann_file="/home/dmsai2/mmdetection/data/classification/annotations/train.json", 
                                 aug_transform=train_aug_transform, valid=False)

valid_dataset = ToothCOCODataset(data_dir="/home/dmsai2/mmdetection/data/classification/test", 
                                 ann_file="/home/dmsai2/mmdetection/data/classification/annotations/test.json", 
                                 aug_transform=train_aug_transform, valid=True)

1991 250
loading annotations into memory...
Done (t=0.29s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [23]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True,
                                           )

valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           pin_memory=False,
                                           )

In [24]:
print("checking dataset and dataloader is okay...")

try:
    _ = next(iter(train_loader))
    print("ok")
except Exception as e:
    print(e)
    print("dataset or dataloader is not ok")

checking dataset and dataloader is okay...
ok


In [25]:
print("start training...")
step = 0

for epoch in range(args.num_epochs):
    print('-' * 50)
    print('Epoch {}/{}'.format(epoch, args.num_epochs))
    step = train(train_loader, model, criterion, optimizer, epoch, train_writer, step)
    acc = validate(valid_loader, model, criterion, epoch, test_writer)

    save_checkpoint(state={'epoch': epoch,
                           'state_dict': model.state_dict(),
                           'best_acc1': acc,
                           'optimizer': optimizer.state_dict(),},
                        is_best=False,
                        filename=args.save_fn + f"_epoch{epoch}.pth",
                    )

start training...
--------------------------------------------------
Epoch 0/50
Train: 100%|██████████| 125/125 [00:58<00:00,  2.14it/s, loss - 0.6675, acc - 0.6158, F1 - 0.5790, Precision - 0.5756, Recall - 0.6158, AUC - 0.5268]
{'precision': 0, 'recall': 0, 'auc': 0}
Valid: 100%|██████████| 16/16 [00:09<00:00,  1.75it/s, loss - 0.6793, acc - 0.5440, F1 - 0.5430, Precision - 0.5420, Recall - 0.5440, AUC - 0.5098]
{'precision': 0, 'recall': 0, 'auc': 0}
pth file saved at /home/dmsai2/mmdetection/work_dir/classification/weights/2024-06-19_05_43/xception_epoch0.pth
--------------------------------------------------
Epoch 1/50
Train: 100%|██████████| 125/125 [00:55<00:00,  2.25it/s, loss - 0.6400, acc - 0.6419, F1 - 0.6082, Precision - 0.6109, Recall - 0.6419, AUC - 0.5560]
{'precision': 0, 'recall': 0, 'auc': 0}
Valid: 100%|██████████| 16/16 [00:08<00:00,  1.87it/s, loss - 0.6605, acc - 0.6200, F1 - 0.5544, Precision - 0.5761, Recall - 0.6200, AUC - 0.5243]
{'precision': 0, 'recall': 0, 

  _warn_prf(average, modifier, msg_start, len(result))


Train: 100%|██████████| 125/125 [00:55<00:00,  2.27it/s, loss - 0.4990, acc - 0.7745, F1 - 0.7677, Precision - 0.7698, Recall - 0.7745, AUC - 0.7303]
{'precision': 0, 'recall': 0, 'auc': 1}
Valid: 100%|██████████| 16/16 [00:08<00:00,  1.84it/s, loss - 0.6183, acc - 0.7080, F1 - 0.7128, Precision - 0.7339, Recall - 0.7080, AUC - 0.7171]
{'precision': 0, 'recall': 0, 'auc': 0}
pth file saved at /home/dmsai2/mmdetection/work_dir/classification/weights/2024-06-19_05_43/xception_epoch8.pth
--------------------------------------------------
Epoch 9/50
Train: 100%|██████████| 125/125 [00:56<00:00,  2.22it/s, loss - 0.4743, acc - 0.7885, F1 - 0.7836, Precision - 0.7846, Recall - 0.7885, AUC - 0.7499]
{'precision': 0, 'recall': 0, 'auc': 0}
Valid: 100%|██████████| 16/16 [00:08<00:00,  1.95it/s, loss - 0.5374, acc - 0.7480, F1 - 0.7366, Precision - 0.7450, Recall - 0.7480, AUC - 0.7007]
{'precision': 0, 'recall': 0, 'auc': 0}
pth file saved at /home/dmsai2/mmdetection/work_dir/classification/wei

RuntimeError: [enforce fail at inline_container.cc:595] . unexpected pos 231007296 vs 231007184

In [None]:
acc, all_labels, all_preds, f1 = validate(valid_loader, model, criterion, 0, test_writer)

In [None]:
print(f1)