-
Notifications
You must be signed in to change notification settings - Fork 356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
如何让faster-rcnn代码兼容maskrcnn的pytorch版本,并不用tf? #217
Comments
class MaskRcnnDataset(Dataset):
|
注意:需要加上 `from torch.nn import functional as F` 这一句,防止出现 `F` 未定义的报错。 |
import torch class MaskRCNNTrainer(nn.Module):
|
class MaskRCNNTrainer(nn.Module):
|
class MaskRCNN(nn.Module): |
class MaskRCNN(nn.Module):
|
import torch def maskrcnn_dataset_collate(batch):
|
train.py文件修改: maskrcnn替换成下面的代码
|
这样就正常了 |
from nets.Maskrcnn import MaskRCNN |
import colorsys import numpy as np class MRCNN(object):
在 frcnn 文件中补充以下代码: |
import numpy as np def expand_boxes(boxes, scale):
def expand_masks(masks, boxes, image_shape, scale):
def random_colors(N, bright=True): def apply_mask(image, mask, color, alpha=0.5): def display_instances(image, boxes, masks, class_ids, class_names,
|
这个是utils_mask.py |
maskrcnn替换成下面的代码
|
frcnn = FRCNN() |
def fit_one_epoch_1(model, train_util, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir):
|
#fit_one_epoch(model, train_util, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir) |
masks = [Image.open(line[0][:-3]+"png").convert('1') for i in range(len(box))] |
masks = [Image.open((line[0][:-3]+"png").replace("JPEG","Segmentation")).convert('1') for i in range(len(box))] |
已有代码,以下是 maskrcnn 在这套基础代码上的实现(注意:页面渲染吞掉了双下划线,代码中的 init 都应写作 __init__):
class MaskRCNNHead(nn.Module):
    """Mask branch for Mask R-CNN; only the constructor is shown in this snippet.

    The layers and ``forward`` of the head are not defined here. The
    constructor records the configuration so that they can be built from it.

    Args:
        n_class: number of mask classes (the visible callers pass
            ``num_classes + 1``).
        roi_size: side length of the RoI features fed to the head.
        spatial_scale: scale mapping box coordinates onto the feature map.
        num_convs: number of convolution layers in the head (default 4).
        conv_dim: channel width of those convolutions (default 256).
        mask_out_dim: side length of the predicted masks (default 28).
    """

    def __init__(self, n_class, roi_size, spatial_scale, num_convs=4, conv_dim=256, mask_out_dim=28):
        # The scraped code read ``init``: markdown rendering drops the double
        # underscores, so the method was never run as the constructor.
        super().__init__()
        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.mask_out_dim = mask_out_dim
def roi_pooling(features, roi, roi_index, roi_size, spatial_scale, sampling_ratio=2):
    """Crop and resample RoIs from ``features`` using RoIAlign.

    This stands in for a dedicated RoI pooling op by calling
    ``torchvision.ops.roi_align`` directly.

    Args:
        features: feature map passed to ``roi_align``; assumed to be
            (N, C, H, W) — TODO confirm with the caller.
        roi: one box ``[x1, y1, x2, y2]``, or a sequence/tensor of K such boxes.
        roi_index: batch index of the image each box belongs to. A single
            value is shared by all boxes; otherwise give one index per box.
        roi_size: side length of the square output.
        spatial_scale: factor that maps box coordinates onto ``features``.
        sampling_ratio: RoIAlign sampling points per bin (default 2, as before).

    Returns:
        The ``roi_align`` output, with one entry per box (K = 1 for a single box).
    """
    import torchvision  # only this function needs torchvision

    rois = _make_roi_boxes(roi, roi_index, features)
    # Functional form: no need to build a new RoIAlign module on every call.
    return torchvision.ops.roi_align(
        features,
        rois,
        output_size=(roi_size, roi_size),
        spatial_scale=spatial_scale,
        sampling_ratio=sampling_ratio,
    )


def _make_roi_boxes(roi, roi_index, like):
    """Build the (K, 5) ``[batch_index, x1, y1, x2, y2]`` tensor that ``roi_align`` expects.

    The result uses the dtype and device of ``like``, because ``roi_align``
    needs boxes of the same dtype as the feature map.

    Raises:
        ValueError: if ``roi_index`` holds neither one index nor one per box.
    """
    boxes = torch.as_tensor(roi, dtype=like.dtype, device=like.device).reshape(-1, 4)
    indices = torch.as_tensor(roi_index, dtype=like.dtype, device=like.device).reshape(-1, 1)
    if indices.shape[0] == 1:
        # A single batch index applies to every box.
        indices = indices.expand(boxes.shape[0], 1)
    elif indices.shape[0] != boxes.shape[0]:
        raise ValueError(
            f"roi_index has {indices.shape[0]} entries but roi has {boxes.shape[0]} boxes"
        )
    return torch.cat([indices, boxes], dim=1)
此外,还需要新建一个文件 maskrcnn.py:
import torch.nn as nn
from nets.classifier import Resnet50RoIHead, VGG16RoIHead, MaskRCNNHead
from nets.resnet50 import resnet50
from nets.rpn import RegionProposalNetwork
from nets.vgg16 import decom_vgg16
class MaskRCNN(nn.Module):
    """Faster R-CNN detector extended with a Mask R-CNN mask branch.

    Builds a backbone feature extractor, a region proposal network (RPN),
    a box classification/regression head and a ``MaskRCNNHead``, for either
    the ``'vgg'`` or the ``'resnet50'`` backbone.

    NOTE(review): this class defines no ``forward``. The mask head still has
    to be wired into the forward pass and the training loss before it takes
    part in training.

    Args:
        num_classes: number of foreground classes. The heads get
            ``num_classes + 1`` outputs to include the background class.
        mode: passed through to ``RegionProposalNetwork``.
        feat_stride: stride passed through to ``RegionProposalNetwork``.
        anchor_scales: RPN anchor scales; defaults to ``[8, 16, 32]``.
        ratios: RPN anchor aspect ratios; defaults to ``[0.5, 1, 2]``.
        backbone: ``'vgg'`` or ``'resnet50'``.
        pretrained: forwarded to the backbone constructor.

    Raises:
        ValueError: if ``backbone`` is not ``'vgg'`` or ``'resnet50'``.
    """

    def __init__(self, num_classes,
                 mode="training",
                 feat_stride=16,
                 anchor_scales=None,
                 ratios=None,
                 backbone='vgg',
                 pretrained=False):
        super().__init__()
        if backbone not in ('vgg', 'resnet50'):
            # Previously an unknown backbone silently built an empty model.
            raise ValueError(
                f"unsupported backbone {backbone!r}; expected 'vgg' or 'resnet50'"
            )
        # Build fresh lists per instance instead of sharing mutable defaults.
        if anchor_scales is None:
            anchor_scales = [8, 16, 32]
        if ratios is None:
            ratios = [0.5, 1, 2]
        self.feat_stride = feat_stride

        # ---------------------------------------------------------------
        # The two backbones differ only in the feature extractor, the RPN
        # input width and the box head (class and RoI size).
        # ---------------------------------------------------------------
        if backbone == 'vgg':
            self.extractor, classifier = decom_vgg16(pretrained)
            rpn_in_channels, box_head_cls, box_roi_size = 512, VGG16RoIHead, 7
        else:  # 'resnet50'
            self.extractor, classifier = resnet50(pretrained)
            rpn_in_channels, box_head_cls, box_roi_size = 1024, Resnet50RoIHead, 14

        # Region proposal network on top of the backbone features.
        self.rpn = RegionProposalNetwork(
            rpn_in_channels, 512,
            ratios=ratios,
            anchor_scales=anchor_scales,
            feat_stride=self.feat_stride,
            mode=mode,
        )
        # Box classification/regression head (+1 for the background class).
        self.head = box_head_cls(
            n_class=num_classes + 1,
            roi_size=box_roi_size,
            spatial_scale=1,
            classifier=classifier,
        )
        # Mask branch; configured the same way for both backbones.
        self.mask_head = MaskRCNNHead(
            n_class=num_classes + 1,
            roi_size=14,
            spatial_scale=1,
        )
这样即可搭建出 maskrcnn 的网络结构。注意:上面的 MaskRCNN 类没有定义 forward,mask_head 还需要接入前向传播和训练损失,mask 分支才会真正参与训练。
The text was updated successfully, but these errors were encountered: