In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt

import os
import multiprocessing as mp
from tqdm import tqdm
import time

# For feature extraction: discrete wavelet transform
import pywt

# For CNN
import torch
import torch.nn as nn
import torch.optim as optim

### Reading images

In [4]:
def _read_split(list_path):
    """Parse a split file whose non-blank lines look like ``<image path> <integer label>``.

    Args:
        list_path: path to the split file (e.g. ``./train.txt``).

    Returns:
        ``(paths, labels)``: two parallel lists of image paths (str) and labels (int).

    Raises:
        ValueError: if a non-blank line does not contain both a path and a label,
            or the label is not an integer.
    """
    paths, labels = [], []
    with open(list_path) as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines instead of crashing on x[1].
                continue
            # Take the label from the LAST field so paths containing spaces survive.
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(f"{list_path}:{lineno}: expected '<path> <label>', got {line!r}")
            path, label = parts
            paths.append(path)
            labels.append(int(label))
    return paths, labels


train_file_list, train_label = _read_split("./train.txt")
val_file_list, val_label = _read_split("./val.txt")
test_file_list, test_label = _read_split("./test.txt")

In [5]:
# Report how many logical CPUs os.cpu_count() sees on this host.
cpu_total = os.cpu_count()
print("# cpus: ", cpu_total)

# cpus:  64


In [6]:
# Worker count for the multiprocessing pools below: at most 8, but never more
# than the CPUs available (os.cpu_count() can return None, hence the `or 1`).
NUM_PROCESSES = min(8, os.cpu_count() or 1)

In [7]:
def ReadImage(filePath):
    """Load one image from disk as a single-channel grayscale array.

    Args:
        filePath: path to the image file (as listed in the split files).

    Returns:
        The array produced by ``cv2.imread(filePath, cv2.IMREAD_GRAYSCALE)``.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or an
            unsupported format). ``cv2.imread`` signals failure by returning
            ``None`` instead of raising, which would otherwise only surface later
            as an opaque error inside ``cv2.resize``.
    """
    image = cv2.imread(filePath, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise OSError(f"cv2.imread could not read image: {filePath!r}")
    return image

In [8]:
with mp.Pool(processes=NUM_PROCESSES) as pool:
    # Wrapping the *input* of pool.map() in tqdm only shows how fast tasks are
    # handed to the workers, not how fast they finish. Wrapping the ordered
    # imap() *output* makes the bar count images actually read.
    train_imgs = list(tqdm(pool.imap(ReadImage, train_file_list, chunksize=64),
                           total=len(train_file_list), desc="read train"))
    val_imgs = list(tqdm(pool.imap(ReadImage, val_file_list, chunksize=64),
                         total=len(val_file_list), desc="read val"))
    test_imgs = list(tqdm(pool.imap(ReadImage, test_file_list, chunksize=64),
                          total=len(test_file_list), desc="read test"))

100%|██████████| 63325/63325 [00:18<00:00, 3495.71it/s]
100%|██████████| 450/450 [00:00<00:00, 232328.51it/s]
100%|██████████| 450/450 [00:00<00:00, 474349.54it/s]


In [12]:
def ResizeImage(image, size=(128, 128)):
    """Resize one image to a fixed size with ``cv2.resize``.

    Args:
        image: image array (the grayscale arrays returned by ``ReadImage`` here).
        size: target size passed as OpenCV's ``dsize``, which is ordered
            ``(width, height)``. Defaults to 128x128.

    Returns:
        The resized image array.
    """
    # NOTE(review): cv2.resize defaults to bilinear interpolation; OpenCV's docs
    # recommend cv2.INTER_AREA when shrinking — consider it if aliasing matters.
    return cv2.resize(image, size)

In [13]:
with mp.Pool(processes=NUM_PROCESSES) as pool:
    # Track the ordered imap() output, not the input, so the progress bar counts
    # finished resizes rather than tasks merely handed to the workers.
    resized_train_imgs = list(tqdm(pool.imap(ResizeImage, train_imgs, chunksize=64),
                                   total=len(train_imgs), desc="resize train"))
    resized_val_imgs = list(tqdm(pool.imap(ResizeImage, val_imgs, chunksize=64),
                                 total=len(val_imgs), desc="resize val"))
    resized_test_imgs = list(tqdm(pool.imap(ResizeImage, test_imgs, chunksize=64),
                                  total=len(test_imgs), desc="resize test"))

100%|██████████| 63325/63325 [00:24<00:00, 2585.92it/s]
100%|██████████| 450/450 [00:00<00:00, 4657.13it/s]
100%|██████████| 450/450 [00:00<00:00, 4899.30it/s]


### Feature extraction: discrete wavelet transform (DWT)

In [14]:
# Discrete Wavelet Transform
def WaveletTransform(image, bands=("cH",), wavelet="haar"):
    """Single-level 2-D discrete wavelet transform, keeping selected sub-bands.

    Args:
        image: image array; the transform runs over axes 0 and 1
            (``mode="symmetric"`` boundary handling).
        bands: sub-band name(s) to keep, in output order. Any of ``"cA"``
            (approximation), ``"cH"``, ``"cV"``, ``"cD"`` (horizontal, vertical,
            diagonal detail). A single name may be passed as a plain string.
            Defaults to the horizontal detail only.
        wavelet: wavelet name passed to ``pywt.dwt2``.

    Returns:
        ``np.ndarray`` stacking the chosen sub-bands along a new leading axis,
        i.e. one channel per band. For the 128x128 inputs here with the default,
        that is ``(1, 64, 64)``.

    Raises:
        ValueError: if ``bands`` is empty or names an unknown sub-band.
    """
    if isinstance(bands, str):
        bands = (bands,)
    cA, (cH, cV, cD) = pywt.dwt2(data=image, wavelet=wavelet, mode="symmetric", axes=(0, 1))
    subbands = {"cA": cA, "cH": cH, "cV": cV, "cD": cD}
    unknown = [b for b in bands if b not in subbands]
    if not bands or unknown:
        raise ValueError(f"bands must be a non-empty subset of {sorted(subbands)}, got {bands!r}")
    return np.array([subbands[b] for b in bands])

In [28]:
with mp.Pool(processes=NUM_PROCESSES) as pool:
    # Track the ordered imap() output, not the input, so the progress bar counts
    # finished transforms rather than tasks merely handed to the workers.
    # np.array stacks the per-image (n_bands, H', W') arrays into
    # (n_images, n_bands, H', W').
    train_features = np.array(list(tqdm(
        pool.imap(WaveletTransform, resized_train_imgs, chunksize=64),
        total=len(resized_train_imgs), desc="dwt train")))
    val_features = np.array(list(tqdm(
        pool.imap(WaveletTransform, resized_val_imgs, chunksize=64),
        total=len(resized_val_imgs), desc="dwt val")))
    test_features = np.array(list(tqdm(
        pool.imap(WaveletTransform, resized_test_imgs, chunksize=64),
        total=len(resized_test_imgs), desc="dwt test")))

100%|██████████| 63325/63325 [00:07<00:00, 8090.36it/s] 
100%|██████████| 450/450 [00:00<00:00, 4965.20it/s]
100%|██████████| 450/450 [00:00<00:00, 12869.30it/s]


In [29]:
# Sanity-check that every feature array lines up with its labels before training.
assert len(train_features) == len(train_label), "train features/labels length mismatch"
assert len(val_features) == len(val_label), "val features/labels length mismatch"
assert len(test_features) == len(test_label), "test features/labels length mismatch"
train_features.shape

(63325, 1, 64, 64)

### CNN

In [None]:
# TODO: build torch DataLoaders for the train/val/test features and labels

In [None]:
class CNN(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is conv(3x3, stride 1, padding 1) -> ReLU -> 2x2 max-pool ->
    BatchNorm. The stages widen the channels 16 -> 32 -> 64 and halve the
    spatial size each time. A global average pool then reduces the 64 feature
    maps to a 64-vector, matching ``fc1``'s ``in_features=64`` for any input
    size. Without it, a 64x64 input would reach ``fc1`` as 64*8*8 = 4096
    features. The pool has no parameters, so the ``state_dict`` keys and shapes
    are unaffected.

    Args:
        in_channels: number of input channels (e.g. one per DWT sub-band);
            default 1.
        num_classes: number of output logits; default 50.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 50) -> None:
        super().__init__()
        # 1
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=16,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2,
                                  stride=2)
        self.batchnorm1 = nn.BatchNorm2d(num_features=16)
        # 2
        self.conv2 = nn.Conv2d(in_channels=16,
                               out_channels=32,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2,
                                  stride=2)
        self.batchnorm2 = nn.BatchNorm2d(num_features=32)
        # 3
        self.conv3 = nn.Conv2d(in_channels=32,
                               out_channels=64,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2,
                                  stride=2)
        self.batchnorm3 = nn.BatchNorm2d(num_features=64)
        # Collapse each 64-channel map to 1x1 so fc1 sees exactly 64 features.
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        # 4
        self.fc1 = nn.Linear(in_features=64,
                             out_features=128)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=128,
                             out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch ``(N, in_channels, H, W)`` to raw logits ``(N, num_classes)``.

        H and W must each be at least 8 so the three 2x2 max-pools leave a
        non-empty map. No softmax is applied.
        """
        x = self.batchnorm1(self.pool1(self.relu1(self.conv1(x))))
        x = self.batchnorm2(self.pool2(self.relu2(self.conv2(x))))
        x = self.batchnorm3(self.pool3(self.relu3(self.conv3(x))))
        x = torch.flatten(self.gap(x), start_dim=1)
        x = self.relu4(self.fc1(x))
        return self.fc2(x)