In [1]:
import cv2
import os
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from util.misc import AverageMeter
import torch.utils.data as data
from dataset import TotalText, Ctw1500Text, Icdar15Text, Mlt2017Text, TD500Text
from network.textnet import TextNet
from util.augmentation import BaseTransform,Augmentation
from cfglib.config import init_config, update_config, print_config
from cfglib.option import BaseOptions
from util.visualize import visualize_detection, visualize_gt
from util.misc import to_device, mkdirs,rescale_result
from util.eval import deal_eval_total_text, deal_eval_ctw1500, deal_eval_icdar15, \
    deal_eval_TD500, data_transfer_ICDAR, data_transfer_TD500, data_transfer_MLT2017
import sys
from network.loss import  TextLoss

# Replace whatever argv the notebook kernel was launched with, so that the
# option parser built from BaseOptions in the following cells sees no CLI flags.
# NOTE(review): presumably BaseOptions parses sys.argv (initialize() returns an
# argparse-style Namespace, see the output of the next cell) — confirm.
sys.argv=['']



In [2]:
# Parse and display the experiment options. With sys.argv blanked, these are
# presumably BaseOptions' defaults (shown in the Namespace output below).
# NOTE(review): the next cell repeats these two statements before building cfg,
# so this cell only serves to display the parsed options.
option = BaseOptions()
args = option.initialize()
args

Namespace(batch_size=12, checkepoch=590, cls_threshold=0.875, cuda=True, dis_threshold=0.3, display_freq=10, exp_name='Totaltext', gamma=0.1, gpu='1', img_root=None, input_size=640, log_dir='./logs/', log_freq=10000, loss='CrossEntropyLoss', lr=0.001, lr_adjust='fix', max_epoch=200, means=(0.485, 0.456, 0.406), mgpu=False, momentum=0.9, net='resnet50', num_workers=8, optim='Adam', pretrain=False, rescale=255.0, resume=None, save_dir='./model/', save_freq=5, start_epoch=0, stds=(0.229, 0.224, 0.225), stepvalues=[], test_size=[640, 1024], val_freq=1000, verbose=True, vis_dir='./vis/', viz=False, viz_freq=50, weight_decay=0.0)

In [2]:
# Build the run configuration: defaults from init_config() overridden by the parsed options.
cfg = init_config()
option = BaseOptions()
args = option.initialize()
update_config(cfg, args)
print_config(cfg)

# Create the checkpoint directory if it does not exist yet.
if not os.path.exists(cfg.save_path):
    mkdirs(cfg.save_path)

# Create the model and put it in training mode.
model = TextNet(is_training=True, backbone=cfg.net)
model.train()

# Optionally track the run with Weights & Biases. The import is deferred so wandb
# is only needed when cfg.wandb_flag is set; at notebook (module) scope the name
# is bound globally already, so no `global` statement is needed.
if cfg.wandb_flag:
    import wandb
    wandb.init(project=cfg.wandb_project, name=cfg.wandb_name, config=cfg)
    wandb.watch(model)

# Load the datasets. These four share one constructor signature and the same
# transform settings, so they are dispatched through a table rather than
# copy-pasted branches.
list_based_datasets = {
    'TotalText': TotalText,
    'CTW1500': Ctw1500Text,
    'ICDAR15': Icdar15Text,
    'MLT2017': Mlt2017Text,
}
if cfg.dataset in list_based_datasets:
    dataset_cls = list_based_datasets[cfg.dataset]
    train_dataset = dataset_cls(cfg.train_data_root, cfg.train_data_list,
                                transform=Augmentation(cfg.input_size, cfg.rgb_mean, cfg.rgb_std, cfg.inter_type))
    val_dataset = dataset_cls(cfg.val_data_root, cfg.val_data_list,
                              transform=BaseTransform(cfg.input_size, cfg.rgb_mean, cfg.rgb_std, cfg.inter_type))
elif cfg.dataset == 'TD500':
    # TD500 takes an is_training flag instead of a data list and reads its own
    # train_*/val_* transform settings (no inter_type argument).
    train_dataset = TD500Text(cfg.train_data_root, is_training=True,
                              transform=Augmentation(cfg.train_input_size, cfg.train_rgb_mean, cfg.train_rgb_std))
    val_dataset = TD500Text(cfg.val_data_root, is_training=False,
                            transform=BaseTransform(cfg.val_input_size, cfg.val_rgb_mean, cfg.val_rgb_std))
else:
    raise NotImplementedError('Unsupported dataset: {}'.format(cfg.dataset))

# Create the dataloaders.
train_loader = data.DataLoader(train_dataset, batch_size=cfg.train_batch_size,
                               shuffle=cfg.train_shuffle, num_workers=cfg.train_num_workers)  # , pin_memory=True)
val_loader = data.DataLoader(val_dataset, batch_size=cfg.val_batch_size,
                             shuffle=cfg.val_shuffle, num_workers=cfg.val_num_workers)  # , pin_memory=True)

# Create the loss criterion.
criterion = TextLoss()

# Load pretrained weights onto the CPU first: a checkpoint saved from a GPU run
# cannot be deserialized without map_location when no GPU is available.
# load_state_dict then copies the tensors into the model's parameters.
if cfg.pretrain:
    print('Loading pretrained model from {}'.format(cfg.pretrain_model))
    model.load_state_dict(torch.load(cfg.pretrain_model, map_location='cpu'))

# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends.
if cfg.use_gpu:
    model = model.cuda()
    criterion = criterion.cuda()

# Create the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train_lr)
# SGD alternative:
# optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum,
#                             weight_decay=cfg.weight_decay, nesterov=cfg.nesterov)

# Create the learning-rate scheduler.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.train_step_size, gamma=cfg.train_gamma)


adj_num: 4
approx_factor: 0.007
batch_size: 12
checkepoch: 590
cls_threshold: 0.875
cuda: True
dataset: TD500
device: cuda
dis_threshold: 0.3
display_freq: 10
epochs: 200
exp_name: Totaltext
gamma: 0.1
global_checkepoch: -1
global_dataset: TD500
global_epochs: 200
global_gpu: [0]
global_output_idr: output
global_pretrain: False
global_pretrain_model: None
global_print_freq: 2
global_save_freq: 10
global_save_path: ./checkpoints/
global_use_gpu: False
global_val_freq: 1
gpu: 1
grad_clip: 0
img_root: None
input_size: 640
log_dir: ./logs/
log_freq: 10000
loss: CrossEntropyLoss
lr: 0.001
lr_adjust: fix
max_annotation: 64
max_epoch: 200
max_points: 20
means: [0.485, 0.456, 0.406]
mgpu: False
momentum: 0.9
net: resnet50
num_points: 20
num_workers: 8
optim: Adam
output_dir: output
output_idr: output
pretrain: False
pretrain_model: None
print_freq: 2
rescale: 255.0
resume: None
save_dir: ./model/
save_freq: 5
save_path: ./checkpoints/
scale: 1
start_epoch: 0
stds: [0.229, 0.224, 0.225]
stepval

In [3]:
# Inspect the validation input size that the TD500 branch of the setup cell
# passed to BaseTransform.
cfg.val_input_size

512

In [9]:
# Rebuild the TD500 evaluation split, this time resizing with cfg.test_size
# instead of the cfg.val_input_size used in the setup cell.
# NOTE(review): args.test_size printed as the list [640, 1024] above; confirm
# BaseTransform is meant to receive this list-valued size.
td500_test_transform = BaseTransform(cfg.test_size, cfg.val_rgb_mean, cfg.val_rgb_std)
val_dataset = TD500Text(cfg.val_data_root, is_training=False, transform=td500_test_transform)

In [10]:
test_loader = data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

# Sanity check: count the test batches whose metadata carries an 'ignore_tags'
# entry and report one summary line, instead of a placeholder debug print per hit.
num_with_ignore_tags = 0
for i, (image, meta) in enumerate(test_loader):
    if 'ignore_tags' in meta:
        num_with_ignore_tags += 1
print('{} of {} test batches have ignore_tags in their metadata'.format(
    num_with_ignore_tags, len(test_loader)))

In [11]:
# Number of batches in test_loader; it was built with batch_size=1, so this is
# also the number of evaluation images.
len(test_loader)

200

In [7]:
# Build the evaluation loader from the TD500 val_dataset defined above; no cell
# in this notebook defines a `testset`, so referencing it only worked through
# leftover kernel state and fails on Restart & Run All.
test_loader = data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

In [8]:
# Smoke test: pull every batch through test_loader once so that any sample that
# fails to load or collate raises here; the batches themselves are discarded.
batch_iter = enumerate(test_loader)
for i, (image, meta) in batch_iter:
    continue