Skip to content

Commit

Permalink
fallback for suggested file name and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shyba committed Oct 12, 2022
1 parent d0aad8c commit 135e735
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lbry/extras/daemon/json_response_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def encode_file(self, managed_stream):
'streaming_url': managed_stream.stream_url,
'stream_hash': managed_stream.stream_hash,
'stream_name': managed_stream.descriptor.stream_name,
'suggested_file_name': managed_stream.descriptor.suggested_file_name,
'suggested_file_name': managed_stream.suggested_file_name,
'sd_hash': managed_stream.descriptor.sd_hash,
'mime_type': managed_stream.mime_type,
'key': managed_stream.descriptor.key,
Expand Down
18 changes: 13 additions & 5 deletions lbry/stream/managed_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,15 @@ def stream_hash(self) -> str:

@property
def file_name(self) -> Optional[str]:
return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None)
return self._file_name or self.suggested_file_name

@property
def suggested_file_name(self) -> Optional[str]:
if self.descriptor and self.descriptor.suggested_file_name and self.descriptor.suggested_file_name.strip():
return self.descriptor.suggested_file_name
elif self.stream_claim_info and self.stream_claim_info.claim:
return sanitize_file_name(self.stream_claim_info.claim.stream.source.name)
return "lbry_download" # default replacement for invalid name. Ideally we should never get here

@property
def written_bytes(self) -> int:
Expand Down Expand Up @@ -116,7 +124,7 @@ def blobs_remaining(self) -> int:

@property
def mime_type(self):
return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0]
return guess_media_type(os.path.basename(self.suggested_file_name))[0]

@property
def download_path(self):
Expand Down Expand Up @@ -162,7 +170,7 @@ async def start(self, timeout: Optional[float] = None,
if not self._file_name:
self._file_name = await get_next_available_file_name(
self.loop, self.download_directory,
self._file_name or sanitize_file_name(self.descriptor.suggested_file_name)
self._file_name or sanitize_file_name(self.suggested_file_name)
)
file_name, download_dir = self._file_name, self.download_directory
else:
Expand Down Expand Up @@ -294,14 +302,14 @@ async def save_file(self, file_name: Optional[str] = None, download_directory: O
self.download_directory = download_directory or self.download_directory or self.config.download_dir
if not self.download_directory:
raise ValueError("no directory to download to")
if not (file_name or self._file_name or self.descriptor.suggested_file_name):
if not (file_name or self._file_name or self.suggested_file_name):
raise ValueError("no file name to download to")
if not os.path.isdir(self.download_directory):
log.warning("download directory '%s' does not exist, attempting to make it", self.download_directory)
os.mkdir(self.download_directory)
self._file_name = await get_next_available_file_name(
self.loop, self.download_directory,
file_name or self._file_name or sanitize_file_name(self.descriptor.suggested_file_name)
file_name or self._file_name or sanitize_file_name(self.suggested_file_name)
)
await self.blob_manager.storage.change_file_download_dir_and_file_name(
self.stream_hash, self.download_directory, self.file_name
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/stream/test_managed_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from lbry.blob_exchange.server import BlobServerProtocol
from lbry.dht.node import Node
from lbry.dht.peer import make_kademlia_peer
from lbry.extras.daemon.storage import StoredContentClaim
from lbry.schema import Claim
from lbry.stream.managed_stream import ManagedStream
from lbry.stream.descriptor import StreamDescriptor
from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase
Expand All @@ -23,7 +25,10 @@ async def create_stream(self, blob_count: int = 10, file_name='test_file'):
with open(file_path, 'wb') as f:
f.write(self.stream_bytes)
descriptor = await StreamDescriptor.create_stream(self.loop, self.server_blob_manager.blob_dir, file_path)
self.sd_hash = descriptor.calculate_sd_hash()
descriptor.suggested_file_name = file_name
descriptor.stream_hash = descriptor.get_stream_hash()
self.sd_hash = descriptor.sd_hash = descriptor.calculate_sd_hash()
await descriptor.make_sd_blob()
return descriptor

async def setup_stream(self, blob_count: int = 10):
Expand All @@ -47,6 +52,19 @@ async def test_client_sanitizes_file_name(self):
self.assertEqual(self.stream.full_path, os.path.join(self.client_dir, 'tt_f'))
self.assertTrue(os.path.isfile(os.path.join(self.client_dir, 'tt_f')))

async def test_empty_name_fallback(self):
descriptor = await self.create_stream(file_name=" ")
descriptor.suggested_file_name = " "
claim = Claim()
claim.stream.source.name = "cool.mp4"
self.stream = ManagedStream(
self.loop, self.client_config, self.client_blob_manager, self.sd_hash, self.client_dir,
claim=StoredContentClaim(serialized=claim.to_bytes().hex())
)
await self._test_transfer_stream(10, skip_setup=True)
self.assertTrue(self.stream.completed)
self.assertEqual(self.stream.suggested_file_name, "cool.mp4")

async def test_status_file_completed(self):
await self._test_transfer_stream(10)
self.assertTrue(self.stream.output_file_exists)
Expand Down

0 comments on commit 135e735

Please sign in to comment.