Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[cuda ext] Protect Cuda Extension Loading #4779

Merged
merged 1 commit into from Sep 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 14 additions & 7 deletions parlai/ops/ngram_repeat_block.py
Expand Up @@ -9,6 +9,7 @@
"""
Wrapper for ngram_repeat_block cuda extension.
"""
import parlai.utils.logging as logging
import torch
from torch import nn

Expand All @@ -20,13 +21,19 @@
dname = os.path.dirname(abspath)
os.chdir(dname)

ngram_repeat_block_cuda = load(
name='ngram_repeat_block_cuda',
sources=[
'../clib/cuda/ngram_repeat_block_cuda.cpp',
'../clib/cuda/ngram_repeat_block_cuda_kernel.cu',
],
)

try:
ngram_repeat_block_cuda = load(
name='ngram_repeat_block_cuda',
sources=[
'../clib/cuda/ngram_repeat_block_cuda.cpp',
'../clib/cuda/ngram_repeat_block_cuda_kernel.cu',
],
)
except Exception as e:
logging.warning(f"Unable to load ngram blocking on GPU: {e}")
ngram_repeat_block_cuda = None

os.chdir(current)


Expand Down