In [None]:
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
from torchvision.transforms import transforms

from model import resnet50
from dataset import DatasetFolder
from dataset import DatasetFolder_wo_Label
from dataset import is_valid_file

import copy
from tqdm import tqdm
from glob import glob

# To pin a specific GPU, set CUDA_VISIBLE_DEVICES before CUDA is initialised,
# e.g. os.environ['CUDA_VISIBLE_DEVICES'] = "3".
# Use the first visible GPU when one exists, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input sizes do not change between iterations.
cudnn.benchmark = True

# Load the trained model

In [None]:
# Build a ResNet-50 for 3-channel input and replace its classifier head with a
# 2-way linear layer so the architecture matches the fine-tuned checkpoint.
# NOTE(review): switch to num_channels=1 if the checkpoint was trained on
# single-channel input — confirm against the training run.
model_ft = resnet50(pretrained=False, num_channels=3)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when `device` fell back to the CPU (without it,
# torch.load raises on CPU-only machines).
model_ft.load_state_dict(torch.load('./checkpoint/best.pt', map_location=device))
model_ft.eval()

# Training-only objects: nothing in the inference path steps them. They are
# kept so cells that still reference optimizer_ft / exp_lr_scheduler work.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.0005, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)

# Load the PNG test dataset

In [None]:
# Evaluation-time preprocessing: resize the shorter side to 1024 px, take the
# central 1024x1024 crop, convert to a tensor, then normalise with
# mean = std = 0.5 (which maps a [0, 1] value range onto [-1, 1]).
data_transforms = dict(
    test=transforms.Compose([
        transforms.Resize(1024),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]),
)

# Input images to classify, and the folder that receives the ones selected by
# the inference loop.
test_dir = '/mnt/dataset/Synthesis_Study/2022/test'
output_dir = '/mnt/dataset/Synthesis_Study/2022/2_X-ray_cleansing_png'
os.makedirs(output_dir, exist_ok=True)

test_dataset = DatasetFolder_wo_Label(
    test_dir,
    transform=data_transforms['test'],
    is_valid_file=is_valid_file,
)

# shuffle=False keeps batches in the same order as test_dataset.samples, which
# the inference loop relies on to recover each image's file path.
loader_options = dict(batch_size=1, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, **loader_options)

dataset_sizes = len(test_dataset)
print(dataset_sizes)

In [None]:
# Classify every test image and copy the ones predicted as class 1 into
# output_dir. This is inference only: no optimizer or gradient bookkeeping.
model = copy.deepcopy(model_ft)
model.eval()  # ensure BatchNorm/Dropout run in inference mode

test_samples = test_dataloader.dataset.samples
# Index in test_samples of the first image of the current batch. The loader
# does not shuffle, so batches arrive in test_samples order. Tracking an offset
# keeps the path lookup correct for any batch_size; indexing by batch number
# only worked when batch_size == 1.
sample_offset = 0

with torch.no_grad():
    for inputs in tqdm(test_dataloader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        batch_preds = preds.cpu().tolist()

        batch_samples = test_samples[sample_offset:sample_offset + len(batch_preds)]
        for sample, pred in zip(batch_samples, batch_preds):
            if pred == 1:
                src = os.path.join(test_dir, sample)
                # NOTE(review): only the basename is kept, so identically named
                # files from different subfolders would overwrite each other.
                dst = os.path.join(output_dir, os.path.basename(sample))
                shutil.copy(src, dst)

        sample_offset += len(batch_preds)