Skip to content

Commit

Permalink
updated the model with xception 25 epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandrequ committed Mar 28, 2024
1 parent 040cae8 commit 98c9d30
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions arcjetCV/segmentation/contour/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,43 @@
class CNN:
def __init__(self):
# predict fails from GUI (Could not load library libcudnn_ops_infer.so), works from script. Fix: run it cpu only
self.device = torch.device("cpu") # torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

self.checkpoint_path = Path(__file__).parent.absolute().joinpath(Path('Unet-xception-last-checkpoint.pt'))
self.device = torch.device(
"cpu"
) # torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

self.checkpoint_path = (
Path(__file__)
.parent.absolute()
.joinpath(Path("Unet-xception_25_original.pt"))
)
self.model = torch.load(self.checkpoint_path, map_location=self.device)

self.mean = [0.485, 0.456, 0.406]
self.std = [0.229, 0.224, 0.225]
self.t = T.Compose([T.ToPILImage(), T.Resize((512, 512)), T.ToTensor(), T.Normalize(self.mean, self.std)])
self.t = T.Compose(
[
T.ToPILImage(),
T.Resize((512, 512)),
T.ToTensor(),
T.Normalize(self.mean, self.std),
]
)

def predict(self, img):

self.model.eval()
image = self.t(img)

self.model.to(self.device)
image = image.to(self.device)

with torch.no_grad():
image = image.unsqueeze(0)
output = self.model(image)
masked = torch.argmax(output, dim=1)
masked = masked.cpu().squeeze(0).numpy()

masked = cv2.resize(masked, (img.shape[1], img.shape[0]), interpolation = cv2.INTER_NEAREST)

masked = cv2.resize(
masked, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
)
return masked

0 comments on commit 98c9d30

Please sign in to comment.