In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [22]:
import sys, torch, os, json, copy
from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
sys.path.append('/content/drive/MyDrive/Colab Notebooks/hemp/download/model')
from quant import QuantizableMobileNetV4
import torch.optim as optim
from tqdm import tqdm
from torch import nn, Tensor
from mobilenet import mobilenetv4_conv_medium
import torch.nn.utils.prune as prune

# **데이터셋 정의**

In [49]:
input_dir = '/content/drive/MyDrive/Colab Notebooks/hemp/download/dataset/seg_images'
label_dir = '/content/drive/MyDrive/Colab Notebooks/hemp/download/dataset/labels'

def load_dataset(input_dir, label_dir):
    intput_files = os.listdir(input_dir)
    label_files = os.listdir(label_dir)

    dataset = []
    for input_file in intput_files:
        label_file = input_file.replace('.png', '.json')
        with open(os.path.join(label_dir, label_file), 'r') as f:
            label_data = json.load(f)
            browning = label_data['annotations']['polygon'][0]['browning']

        dataset.append({
            'image_path': os.path.join(input_dir, input_file),
            'label': f'{browning}'
        })
    return dataset

dataset = load_dataset(input_dir, label_dir)
labels = [data['label'] for data in dataset]

train_split = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
for train_idx, temp_idx in train_split.split(dataset, labels):
    train_set = [dataset[i] for i in train_idx]
    temp_set = [dataset[i] for i in temp_idx]
val_test_split = StratifiedShuffleSplit(n_splits=1, test_size=2/3, random_state=42)
labels_temp = [labels[i] for i in temp_idx]
for val_idx, test_idx in val_test_split.split(temp_set, labels_temp):
    val_set = [temp_set[i] for i in val_idx]
    test_set = [temp_set[i] for i in test_idx]

train_label_count = Counter([data['label'] for data in train_set])
val_label_count   = Counter([data['label'] for data in val_set])
test_label_count  = Counter([data['label'] for data in test_set])

print("Train set class counts:", train_label_count)
print("Validation set class counts:", val_label_count)
print("Test set class counts:", test_label_count)

Train set class counts: Counter({'True': 35, 'False': 35})
Validation set class counts: Counter({'False': 5, 'True': 5})
Test set class counts: Counter({'False': 10, 'True': 10})


In [50]:
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=MEAN, std=STD)
transform = transforms.Compose([
    transforms.ToTensor()])

def preprocess_image(image):
    non_zero_mask = (image > 0).any(axis=-1)
    non_zero_rows = np.any(non_zero_mask, axis=1)
    non_zero_cols = np.any(non_zero_mask, axis=0)
    min_row, max_row = np.where(non_zero_rows)[0][[0, -1]]
    min_col, max_col = np.where(non_zero_cols)[0][[0, -1]]
    image = image[min_row:max_row+1, min_col:max_col+1]
    image = Image.fromarray(image)
    image = image.resize((224, 224))
    return image

class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None, num_classes=2):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.num_classes = num_classes

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert('RGB'))
        image = preprocess_image(image)
        label = int(self.labels[idx])

        if self.transform:
            image = self.transform(image)
        return image, label

def create_dataloader(dataset, transform, batch_size, shuffle, num_workers):
    image_paths = [item['image_path'] for item in dataset]
    labels = [0 if item['label'] == 'True' else 1 for item in dataset]
    custom_dataset = CustomDataset(image_paths, labels, transform=transform)
    return DataLoader(custom_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

train_loader = create_dataloader(train_set, transform, batch_size=24, shuffle=True, num_workers=2)
val_loader = create_dataloader(val_set, transform, batch_size=24, shuffle=False, num_workers=2)
test_loader = create_dataloader(test_set, transform, batch_size=24, shuffle=False, num_workers=2)

# **FP32**

In [51]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

cpu_device = torch.device("cpu")
CE = nn.CrossEntropyLoss()

pretrained_path = '/content/drive/MyDrive/Colab Notebooks/hemp/download/model/pretrain.pth'
FP32 = mobilenetv4_conv_medium(num_classes=2)
FP32.load_state_dict(torch.load(pretrained_path, map_location=cpu_device))
FP32.eval()

total_loss = 0.0
correct_predictions = 0
all_labels = []
all_preds = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Test Progress', leave=False):
        inputs = normalize(inputs).to(cpu_device)
        labels = labels.to(cpu_device)

        outputs = FP32(inputs)
        loss = CE(outputs, labels)
        _, preds = torch.max(outputs, 1)

        total_loss += loss.item() * inputs.size(0)
        correct_predictions += (preds == labels).sum().item()

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

avg_test_loss = total_loss / len(test_loader.dataset)
avg_test_acc = accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
test_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)
test_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=1)

print(f"Test Metrics: "
      f"Loss={avg_test_loss:.4f}, "
      f"Acc={avg_test_acc:.4f}, "
      f"Precision={test_precision:.4f}, "
      f"Recall={test_recall:.4f}, "
      f"F1={test_f1:.4f}")

print_model_size(FP32)

  FP32.load_state_dict(torch.load(pretrained_path, map_location=cpu_device))
                                                            

Test Metrics: Loss=0.1260, Acc=0.9500, Precision=0.9545, Recall=0.9500, F1=0.9499
34.15 MB




# **비구조적 가지치기 (Unstructured pruning)**

In [67]:
model = QuantizableMobileNetV4(num_classes=2)
model.to(cpu_device)

model.load_state_dict(torch.load(pretrained_path, map_location=cpu_device))

model.eval()
model.fuse_model()
model.train()

model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
torch.ao.quantization.prepare_qat(model, inplace=True)

ori_model = copy.deepcopy(model)

  model.load_state_dict(torch.load(pretrained_path, map_location=cpu_device))


In [86]:
prunned_model = copy.deepcopy(ori_model)

for name, module in prunned_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.l1_unstructured(module, name='weight', amount=0.3)

def count_sparsity(model):
    total_params = 0
    total_zero = 0
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weight_sparsity = float(torch.sum(module.weight == 0)) / module.weight.numel()
            print(f"{name}: weight sparsity={weight_sparsity*100:.2f}%")
            total_zero += torch.sum(module.weight == 0).item()
            total_params += module.weight.numel()
    print(f"Total zeros: {total_zero} out of {total_params} ({(total_zero / total_params) * 100:.2f}%)")

count_sparsity(prunned_model)

features.0.block.0: weight sparsity=29.98%
features.1.block.0: weight sparsity=30.00%
features.2.block.0: weight sparsity=30.00%
features.3.start_dw_conv: weight sparsity=30.09%
features.3.expand_conv: weight sparsity=30.00%
features.3.middle_dw_conv: weight sparsity=30.00%
features.3.proj_conv: weight sparsity=30.00%
features.4.start_dw_conv: weight sparsity=30.00%
features.4.expand_conv: weight sparsity=30.00%
features.4.middle_dw_conv: weight sparsity=30.00%
features.4.proj_conv: weight sparsity=30.00%
features.5.start_dw_conv: weight sparsity=30.00%
features.5.expand_conv: weight sparsity=30.00%
features.5.middle_dw_conv: weight sparsity=30.00%
features.5.proj_conv: weight sparsity=30.00%
features.6.start_dw_conv: weight sparsity=30.00%
features.6.expand_conv: weight sparsity=30.00%
features.6.middle_dw_conv: weight sparsity=30.00%
features.6.proj_conv: weight sparsity=30.00%
features.7.start_dw_conv: weight sparsity=30.00%
features.7.expand_conv: weight sparsity=30.00%
features.7.

In [87]:
for name, module in prunned_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        if hasattr(module, 'weight_mask'):
            prune.remove(module, 'weight')
prunned_model.eval()

total_loss = 0.0
correct_predictions = 0
all_labels = []
all_preds = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Test Progress', leave=False):
        inputs = normalize(inputs).to(cpu_device)
        labels = labels.to(cpu_device)

        outputs = prunned_model(inputs)
        loss = CE(outputs, labels)
        _, preds = torch.max(outputs, 1)

        total_loss += loss.item() * inputs.size(0)
        correct_predictions += (preds == labels).sum().item()

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

avg_test_loss = total_loss / len(test_loader.dataset)
avg_test_acc = accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
test_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)
test_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=1)

print(f"Test Metrics: "
      f"Loss={avg_test_loss:.4f}, "
      f"Acc={avg_test_acc:.4f}, "
      f"Precision={test_precision:.4f}, "
      f"Recall={test_recall:.4f}, "
      f"F1={test_f1:.4f}")

print_model_size(prunned_model)



Test Metrics: Loss=0.6137, Acc=0.5000, Precision=0.7500, Recall=0.5000, F1=0.3333
35.13 MB


# **양자화 인지 훈련 (Quantization Aware Training)**

In [83]:
num_epochs = 50
patience = 5
best_val_acc = 0.0
best_val_loss = float('inf')
epochs_no_improve = 0
early_stop = False

QAT_model = copy.deepcopy(ori_model)

for params in QAT_model.parameters():
    params.requires_grad = True

for name, module in QAT_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.l1_unstructured(module, name='weight', amount=0.3)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
QAT_model.to(device)

optimizer = optim.Adam(QAT_model.parameters(), lr=0.001)

def train_val_epoch(model, dataloader, phase, optimizer=None):
    if phase == 'train':
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    all_preds = []
    all_labels = []

    for inputs, labels in tqdm(dataloader, desc=f'  {phase} Progress', leave=False):
        inputs = normalize(inputs).to(device)
        labels = labels.to(device)

        if phase == 'train':
            optimizer.zero_grad()

        with torch.set_grad_enabled(phase == 'train'):
            outputs = model(inputs)
            loss = CE(outputs, labels)
            _, preds = torch.max(outputs, 1)

            if phase == 'train':
                loss.backward()
                optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        all_preds.extend(preds.detach().cpu().numpy())
        all_labels.extend(labels.detach().cpu().numpy())

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
    epoch_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)
    epoch_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=1)

    epoch_metrics = {
        'loss': epoch_loss,
        'accuracy': epoch_acc,
        'precision': epoch_precision,
        'recall': epoch_recall,
        'f1': epoch_f1
    }

    return epoch_loss, epoch_acc, epoch_metrics

best_model_wts = copy.deepcopy(QAT_model.state_dict())

for epoch in range(num_epochs):
    if early_stop:
        print(f'Early stopping at epoch {epoch + 1}')
        break

    print(f"Epoch {epoch+1}/{num_epochs}:")

    train_loss, train_acc, train_metrics = train_val_epoch(QAT_model, train_loader, 'train', optimizer)

    val_loss, val_acc, val_metrics = train_val_epoch(QAT_model, val_loader, 'val')

    print(f"  Train Metrics:      Loss={train_loss:.4f}, Acc={train_acc:.4f}, "
          f"Precision={train_metrics['precision']:.4f}, "
          f"Recall={train_metrics['recall']:.4f}, "
          f"F1={train_metrics['f1']:.4f}")
    print(f"  Validation Metrics: Loss={val_loss:.4f}, Acc={val_acc:.4f}, "
          f"Precision={val_metrics['precision']:.4f}, "
          f"Recall={val_metrics['recall']:.4f}, "
          f"F1={val_metrics['f1']:.4f}")

    # 모델 저장 및 조기 종료 조건
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(QAT_model.state_dict())
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"  Patience: {epochs_no_improve}")
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            early_stop = True

Epoch 1/2:




  Train Metrics:      Loss=0.5912, Acc=0.8286, Precision=0.8723, Recall=0.8286, F1=0.8234
  Validation Metrics: Loss=0.3781, Acc=0.7000, Precision=0.8125, Recall=0.7000, F1=0.6703
Epoch 2/2:


                                                             

  Train Metrics:      Loss=0.1897, Acc=0.9000, Precision=0.9167, Recall=0.9000, F1=0.8990
  Validation Metrics: Loss=0.3754, Acc=0.7000, Precision=0.8125, Recall=0.7000, F1=0.6703




In [84]:
QAT_model.load_state_dict(best_model_wts)
for name, module in QAT_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        if hasattr(module, 'weight_mask'):
            prune.remove(module, 'weight')

QAT_model.to(cpu_device)
int_model = torch.ao.quantization.convert(QAT_model.eval(), inplace=False)
int8_path = '/content/drive/MyDrive/Colab Notebooks/hemp/model/int8.pth'
torch.save(int_model.state_dict(), int8_path)



# **양자화 모델 추론**

In [85]:
int_model = copy.deepcopy(ori_model)
int_model.to(cpu_device)
int_model = torch.ao.quantization.convert(int_model.eval(), inplace=False)
int_model.load_state_dict(torch.load(int8_path))
int_model.eval()

total_loss = 0.0
correct_predictions = 0
all_labels = []
all_preds = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc='Test Progress', leave=False):
        inputs = normalize(inputs).to(cpu_device)
        labels = labels.to(cpu_device)

        outputs = int_model(inputs)
        loss = CE(outputs, labels)
        _, preds = torch.max(outputs, 1)

        total_loss += loss.item() * inputs.size(0)
        correct_predictions += (preds == labels).sum().item()

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

avg_test_loss = total_loss / len(test_loader.dataset)
avg_test_acc = accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
test_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)
test_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=1)

print(f"Test Metrics: "
      f"Loss={avg_test_loss:.4f}, "
      f"Acc={avg_test_acc:.4f}, "
      f"Precision={test_precision:.4f}, "
      f"Recall={test_recall:.4f}, "
      f"F1={test_f1:.4f}")

print_model_size(int_model)

  int_model.load_state_dict(torch.load(int8_path))
                                                            

Test Metrics: Loss=0.2408, Acc=0.9500, Precision=0.9545, Recall=0.9500, F1=0.9499
9.19 MB


