diff --git a/src/Api/Api.csproj b/src/Api/Api.csproj
index 5ee677cbeef8..1373497c7128 100644
--- a/src/Api/Api.csproj
+++ b/src/Api/Api.csproj
@@ -26,6 +26,7 @@
+
diff --git a/src/Api/Controllers/SendsController.cs b/src/Api/Controllers/SendsController.cs
index 511b76a92d56..d67c672ff57d 100644
--- a/src/Api/Controllers/SendsController.cs
+++ b/src/Api/Controllers/SendsController.cs
@@ -7,11 +7,13 @@
using Bit.Core.Models.Api;
using Bit.Core.Exceptions;
using Bit.Core.Services;
-using Bit.Api.Utilities;
-using Bit.Core.Models.Table;
using Bit.Core.Utilities;
using Bit.Core.Settings;
using Bit.Core.Models.Api.Response;
+using Bit.Core.Enums;
+using Microsoft.Azure.EventGrid.Models;
+using Bit.Api.Utilities;
+using System.Collections.Generic;
namespace Bit.Api.Controllers
{
@@ -69,14 +71,40 @@ public async Task Access(string id, [FromBody] SendAccessRequestM
}
[AllowAnonymous]
- [HttpGet("access/file/{id}")]
- public async Task GetSendFileDownloadData(string id)
+ [HttpPost("{encodedSendId}/access/file/{fileId}")]
+ public async Task GetSendFileDownloadData(string encodedSendId,
+ string fileId, [FromBody] SendAccessRequestModel model)
{
- return new SendFileDownloadDataResponseModel()
+ var sendId = new Guid(CoreHelpers.Base64UrlDecode(encodedSendId));
+ var send = await _sendRepository.GetByIdAsync(sendId);
+
+ if (send == null)
{
- Id = id,
- Url = await _sendFileStorageService.GetSendFileDownloadUrlAsync(id),
- };
+ throw new BadRequestException("Could not locate send");
+ }
+
+ var (url, passwordRequired, passwordInvalid) = await _sendService.GetSendFileDownloadUrlAsync(send, fileId,
+ model.Password);
+
+ if (passwordRequired)
+ {
+ return new UnauthorizedResult();
+ }
+ if (passwordInvalid)
+ {
+ await Task.Delay(2000);
+ throw new BadRequestException("Invalid password.");
+ }
+ if (send == null)
+ {
+ throw new NotFoundException();
+ }
+
+ return new ObjectResult(new SendFileDownloadDataResponseModel()
+ {
+ Id = fileId,
+ Url = url,
+ });
}
[HttpGet("{id}")]
@@ -112,31 +140,77 @@ public async Task Post([FromBody] SendRequestModel model)
}
[HttpPost("file")]
- [RequestSizeLimit(105_906_176)]
+ public async Task PostFile([FromBody] SendRequestModel model)
+ {
+ if (model.Type != SendType.File)
+ {
+ throw new BadRequestException("Invalid content.");
+ }
+
+ if (!model.FileLength.HasValue)
+ {
+ throw new BadRequestException("Invalid content. File size hint is required.");
+ }
+
+ var userId = _userService.GetProperUserId(User).Value;
+ var (send, data) = model.ToSend(userId, model.File.FileName, _sendService);
+ var uploadUrl = await _sendService.SaveFileSendAsync(send, data, model.FileLength.Value);
+ return new SendFileUploadDataResponseModel
+ {
+ Url = uploadUrl,
+ FileUploadType = _sendFileStorageService.FileUploadType,
+ SendResponse = new SendResponseModel(send, _globalSettings)
+ };
+ }
+
+ [HttpPost("{id}/file/{fileId}")]
[DisableFormValueModelBinding]
- public async Task PostFile()
+ public async Task PostFileForExistingSend(string id, string fileId)
{
if (!Request?.ContentType.Contains("multipart/") ?? true)
{
throw new BadRequestException("Invalid content.");
}
- if (Request.ContentLength > 105906176) // 101 MB, give em' 1 extra MB for cushion
+ if (Request.ContentLength > 105906176 && !_globalSettings.SelfHosted) // 101 MB, give em' 1 extra MB for cushion
{
- throw new BadRequestException("Max file size is 100 MB.");
+ throw new BadRequestException("Max file size for direct upload is 100 MB.");
}
- Send send = null;
- await Request.GetSendFileAsync(async (stream, fileName, model) =>
+ var send = await _sendRepository.GetByIdAsync(new Guid(id));
+ await Request.GetSendFileAsync(async (stream) =>
{
- model.ValidateCreation();
- var userId = _userService.GetProperUserId(User).Value;
- var (madeSend, madeData) = model.ToSend(userId, fileName, _sendService);
- send = madeSend;
- await _sendService.CreateSendAsync(send, madeData, stream, Request.ContentLength.GetValueOrDefault(0));
+ await _sendService.UploadFileToExistingSendAsync(stream, send);
});
+ }
- return new SendResponseModel(send, _globalSettings);
+ [AllowAnonymous]
+ [HttpPost("file/validate/azure")]
+ public async Task AzureValidateFile()
+ {
+ return await ApiHelpers.HandleAzureEvents(Request, new Dictionary>
+ {
+ {
+ "Microsoft.Storage.BlobCreated", async (eventGridEvent) =>
+ {
+ try
+ {
+ var blobName = eventGridEvent.Subject.Split($"{AzureSendFileStorageService.FilesContainerName}/blobs/")[1];
+ var sendId = AzureSendFileStorageService.SendIdFromBlobName(blobName);
+ var send = await _sendRepository.GetByIdAsync(new Guid(sendId));
+ if (send == null)
+ {
+ return;
+ }
+ await _sendService.ValidateSendFile(send);
+ }
+ catch
+ {
+ return;
+ }
+ }
+ }
+ });
}
[HttpPut("{id}")]
diff --git a/src/Api/Utilities/ApiHelpers.cs b/src/Api/Utilities/ApiHelpers.cs
index 8aef098b52b5..20ed178762fe 100644
--- a/src/Api/Utilities/ApiHelpers.cs
+++ b/src/Api/Utilities/ApiHelpers.cs
@@ -1,5 +1,10 @@
using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Azure.EventGrid;
+using Microsoft.Azure.EventGrid.Models;
using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
@@ -29,5 +34,47 @@ public async static Task ReadJsonFileFromBody(HttpContext httpContext, IFo
return obj;
}
+
+ ///
+ /// Validates Azure event subscription and calls the appropriate event handler. Responds HttpOk.
+ ///
+ /// HttpRequest received from Azure
+ /// Dictionary of eventType strings and their associated handlers.
+ /// OkObjectResult
+ /// Reference https://docs.microsoft.com/en-us/azure/event-grid/receive-events
+ public async static Task HandleAzureEvents(HttpRequest request,
+ Dictionary> eventTypeHandlers)
+ {
+ var response = string.Empty;
+ var requestContent = await new StreamReader(request.Body).ReadToEndAsync();
+ if (string.IsNullOrWhiteSpace(requestContent))
+ {
+ return new OkObjectResult(response);
+ }
+
+ var eventGridSubscriber = new EventGridSubscriber();
+ var eventGridEvents = eventGridSubscriber.DeserializeEventGridEvents(requestContent);
+
+ foreach (var eventGridEvent in eventGridEvents)
+ {
+ if (eventGridEvent.Data is SubscriptionValidationEventData eventData)
+ {
+ // Might want to enable additional validation: subject, topic etc.
+
+ var responseData = new SubscriptionValidationResponse()
+ {
+ ValidationResponse = eventData.ValidationCode
+ };
+
+ return new OkObjectResult(responseData);
+ }
+ else if (eventTypeHandlers.ContainsKey(eventGridEvent.EventType))
+ {
+ await eventTypeHandlers[eventGridEvent.EventType](eventGridEvent);
+ }
+ }
+
+ return new OkObjectResult(response);
+ }
}
}
diff --git a/src/Api/Utilities/MultipartFormDataHelper.cs b/src/Api/Utilities/MultipartFormDataHelper.cs
index 03ed0f1ae722..261b8992f3b5 100644
--- a/src/Api/Utilities/MultipartFormDataHelper.cs
+++ b/src/Api/Utilities/MultipartFormDataHelper.cs
@@ -66,45 +66,24 @@ public static async Task GetFileAsync(this HttpRequest request, Func callback)
+ public static async Task GetSendFileAsync(this HttpRequest request, Func callback)
{
var boundary = GetBoundary(MediaTypeHeaderValue.Parse(request.ContentType),
_defaultFormOptions.MultipartBoundaryLengthLimit);
var reader = new MultipartReader(boundary, request.Body);
- var firstSection = await reader.ReadNextSectionAsync();
- if (firstSection != null)
+ var dataSection = await reader.ReadNextSectionAsync();
+ if (dataSection != null)
{
- if (ContentDispositionHeaderValue.TryParse(firstSection.ContentDisposition, out _))
+ if (ContentDispositionHeaderValue.TryParse(dataSection.ContentDisposition, out var dataContent)
+ && HasFileContentDisposition(dataContent))
{
- // Request model json, then data
- string requestModelJson = null;
- using (var sr = new StreamReader(firstSection.Body))
+ using (dataSection.Body)
{
- requestModelJson = await sr.ReadToEndAsync();
+ await callback(dataSection.Body);
}
-
- var secondSection = await reader.ReadNextSectionAsync();
- if (secondSection != null)
- {
- if (ContentDispositionHeaderValue.TryParse(secondSection.ContentDisposition,
- out var secondContent) && HasFileContentDisposition(secondContent))
- {
- var fileName = HeaderUtilities.RemoveQuotes(secondContent.FileName).ToString();
- using (secondSection.Body)
- {
- var model = JsonConvert.DeserializeObject(requestModelJson);
- await callback(secondSection.Body, fileName, model);
- }
- }
-
- secondSection = null;
- }
-
}
-
- firstSection = null;
+ dataSection = null;
}
}
diff --git a/src/Core/Enums/FileUploadType.cs b/src/Core/Enums/FileUploadType.cs
new file mode 100644
index 000000000000..dc50eb669620
--- /dev/null
+++ b/src/Core/Enums/FileUploadType.cs
@@ -0,0 +1,8 @@
+namespace Bit.Core.Enums
+{
+ public enum FileUploadType
+ {
+ Direct = 0,
+ Azure = 1,
+ }
+}
diff --git a/src/Core/Models/Api/Request/SendRequestModel.cs b/src/Core/Models/Api/Request/SendRequestModel.cs
index ccb27b0cc56b..d64faef176a8 100644
--- a/src/Core/Models/Api/Request/SendRequestModel.cs
+++ b/src/Core/Models/Api/Request/SendRequestModel.cs
@@ -13,6 +13,7 @@ namespace Bit.Core.Models.Api
public class SendRequestModel
{
public SendType Type { get; set; }
+ public long? FileLength { get; set; } = null;
[EncryptedString]
[EncryptedStringLength(1000)]
public string Name { get; set; }
diff --git a/src/Core/Models/Api/Response/SendFileUploadDataResponseModel.cs b/src/Core/Models/Api/Response/SendFileUploadDataResponseModel.cs
new file mode 100644
index 000000000000..7e4b95ddbf30
--- /dev/null
+++ b/src/Core/Models/Api/Response/SendFileUploadDataResponseModel.cs
@@ -0,0 +1,13 @@
+using Bit.Core.Enums;
+
+namespace Bit.Core.Models.Api.Response
+{
+ public class SendFileUploadDataResponseModel : ResponseModel
+ {
+ public string Url { get; set; }
+ public FileUploadType FileUploadType { get; set; }
+ public SendResponseModel SendResponse { get; set; }
+
+ public SendFileUploadDataResponseModel() : base("send-fileUpload") { }
+ }
+}
diff --git a/src/Core/Services/ISendService.cs b/src/Core/Services/ISendService.cs
index 6b44a3824f00..b4ea58971d98 100644
--- a/src/Core/Services/ISendService.cs
+++ b/src/Core/Services/ISendService.cs
@@ -10,8 +10,11 @@ public interface ISendService
{
Task DeleteSendAsync(Send send);
Task SaveSendAsync(Send send);
- Task CreateSendAsync(Send send, SendFileData data, Stream stream, long requestLength);
+ Task SaveFileSendAsync(Send send, SendFileData data, long fileLength);
+ Task UploadFileToExistingSendAsync(Stream stream, Send send);
Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password);
string HashPassword(string password);
+ Task ValidateSendFile(Send send);
+ Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password);
}
}
diff --git a/src/Core/Services/ISendStorageService.cs b/src/Core/Services/ISendStorageService.cs
index 1b4c5cc3718c..91829070e623 100644
--- a/src/Core/Services/ISendStorageService.cs
+++ b/src/Core/Services/ISendStorageService.cs
@@ -1,4 +1,5 @@
-using Bit.Core.Models.Table;
+using Bit.Core.Enums;
+using Bit.Core.Models.Table;
using System;
using System.IO;
using System.Threading.Tasks;
@@ -7,10 +8,13 @@ namespace Bit.Core.Services
{
public interface ISendFileStorageService
{
+ FileUploadType FileUploadType { get; }
Task UploadNewFileAsync(Stream stream, Send send, string fileId);
- Task DeleteFileAsync(string fileId);
+ Task DeleteFileAsync(Send send, string fileId);
Task DeleteFilesForOrganizationAsync(Guid organizationId);
Task DeleteFilesForUserAsync(Guid userId);
- Task GetSendFileDownloadUrlAsync(string fileId);
+ Task GetSendFileDownloadUrlAsync(Send send, string fileId);
+ Task GetSendFileUploadUrlAsync(Send send, string fileId);
+ Task ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway);
}
}
diff --git a/src/Core/Services/Implementations/AzureSendFileStorageService.cs b/src/Core/Services/Implementations/AzureSendFileStorageService.cs
index 50edc58cb3fa..2b646764669d 100644
--- a/src/Core/Services/Implementations/AzureSendFileStorageService.cs
+++ b/src/Core/Services/Implementations/AzureSendFileStorageService.cs
@@ -5,17 +5,23 @@
using System;
using Bit.Core.Models.Table;
using Bit.Core.Settings;
+using Bit.Core.Enums;
namespace Bit.Core.Services
{
public class AzureSendFileStorageService : ISendFileStorageService
{
- private const string FilesContainerName = "sendfiles";
+ public const string FilesContainerName = "sendfiles";
private static readonly TimeSpan _downloadLinkLiveTime = TimeSpan.FromMinutes(1);
private readonly CloudBlobClient _blobClient;
private CloudBlobContainer _sendFilesContainer;
+ public FileUploadType FileUploadType => FileUploadType.Azure;
+
+ public static string SendIdFromBlobName(string blobName) => blobName.Split('/')[0];
+ public static string BlobName(Send send, string fileId) => $"{send.Id}/{fileId}";
+
public AzureSendFileStorageService(
GlobalSettings globalSettings)
{
@@ -26,7 +32,7 @@ public AzureSendFileStorageService(
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
await InitAsync();
- var blob = _sendFilesContainer.GetBlockBlobReference(fileId);
+ var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
if (send.UserId.HasValue)
{
blob.Metadata.Add("userId", send.UserId.Value.ToString());
@@ -39,10 +45,10 @@ public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
await blob.UploadFromStreamAsync(stream);
}
- public async Task DeleteFileAsync(string fileId)
+ public async Task DeleteFileAsync(Send send, string fileId)
{
await InitAsync();
- var blob = _sendFilesContainer.GetBlockBlobReference(fileId);
+ var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
await blob.DeleteIfExistsAsync();
}
@@ -56,19 +62,67 @@ public async Task DeleteFilesForUserAsync(Guid userId)
await InitAsync();
}
- public async Task GetSendFileDownloadUrlAsync(string fileId)
+ public async Task GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
- var blob = _sendFilesContainer.GetBlockBlobReference(fileId);
+ var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
var accessPolicy = new SharedAccessBlobPolicy()
{
SharedAccessExpiryTime = DateTime.UtcNow.Add(_downloadLinkLiveTime),
- Permissions = SharedAccessBlobPermissions.Read
+ Permissions = SharedAccessBlobPermissions.Read,
};
return blob.Uri + blob.GetSharedAccessSignature(accessPolicy);
}
+ public async Task GetSendFileUploadUrlAsync(Send send, string fileId)
+ {
+ await InitAsync();
+ var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
+
+ var accessPolicy = new SharedAccessBlobPolicy()
+ {
+ SharedAccessExpiryTime = DateTime.UtcNow.Add(_downloadLinkLiveTime),
+ Permissions = SharedAccessBlobPermissions.Create,
+ };
+
+ return blob.Uri + blob.GetSharedAccessSignature(accessPolicy);
+ }
+
+ public async Task ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
+ {
+ await InitAsync();
+
+ var blob = _sendFilesContainer.GetBlockBlobReference(BlobName(send, fileId));
+
+ if (!blob.Exists())
+ {
+ return false;
+ }
+
+ blob.FetchAttributes();
+
+ if (send.UserId.HasValue)
+ {
+ blob.Metadata["userId"] = send.UserId.Value.ToString();
+ }
+ else
+ {
+ blob.Metadata["organizationId"] = send.OrganizationId.Value.ToString();
+ }
+ blob.Properties.ContentDisposition = $"attachment; filename=\"{fileId}\"";
+ blob.SetMetadata();
+ blob.SetProperties();
+
+ var length = blob.Properties.Length;
+ if (length < expectedFileSize - leeway || length > expectedFileSize + leeway)
+ {
+ return false;
+ }
+
+ return true;
+ }
+
private async Task InitAsync()
{
if (_sendFilesContainer == null)
diff --git a/src/Core/Services/Implementations/LocalSendStorageService.cs b/src/Core/Services/Implementations/LocalSendStorageService.cs
index a015149e8a29..ae4c20cf2ef8 100644
--- a/src/Core/Services/Implementations/LocalSendStorageService.cs
+++ b/src/Core/Services/Implementations/LocalSendStorageService.cs
@@ -3,6 +3,8 @@
using System;
using Bit.Core.Models.Table;
using Bit.Core.Settings;
+using Bit.Core.Enums;
+using System.Linq;
namespace Bit.Core.Services
{
@@ -11,6 +13,10 @@ public class LocalSendStorageService : ISendFileStorageService
private readonly string _baseDirPath;
private readonly string _baseSendUrl;
+ public FileUploadType FileUploadType => FileUploadType.Direct;
+ private string RelativeFilePath(Send send, string fileID) => $"{send.Id}/{fileID}";
+ private string FilePath(Send send, string fileID) => $"{_baseDirPath}/{RelativeFilePath(send, fileID)}";
+
public LocalSendStorageService(
GlobalSettings globalSettings)
{
@@ -21,17 +27,21 @@ public LocalSendStorageService(
public async Task UploadNewFileAsync(Stream stream, Send send, string fileId)
{
await InitAsync();
- using (var fs = File.Create($"{_baseDirPath}/{fileId}"))
+ var path = FilePath(send, fileId);
+ Directory.CreateDirectory(Path.GetDirectoryName(path));
+ using (var fs = File.Create(path))
{
stream.Seek(0, SeekOrigin.Begin);
await stream.CopyToAsync(fs);
}
}
- public async Task DeleteFileAsync(string fileId)
+ public async Task DeleteFileAsync(Send send, string fileId)
{
await InitAsync();
- DeleteFileIfExists($"{_baseDirPath}/{fileId}");
+ var path = FilePath(send, fileId);
+ DeleteFileIfExists(path);
+ DeleteDirectoryIfExistsAndEmpty(Path.GetDirectoryName(path));
}
public async Task DeleteFilesForOrganizationAsync(Guid organizationId)
@@ -44,10 +54,10 @@ public async Task DeleteFilesForUserAsync(Guid userId)
await InitAsync();
}
- public async Task GetSendFileDownloadUrlAsync(string fileId)
+ public async Task GetSendFileDownloadUrlAsync(Send send, string fileId)
{
await InitAsync();
- return $"{_baseSendUrl}/{fileId}";
+ return $"{_baseSendUrl}/{RelativeFilePath(send, fileId)}";
}
private void DeleteFileIfExists(string path)
@@ -58,6 +68,14 @@ private void DeleteFileIfExists(string path)
}
}
+ private void DeleteDirectoryIfExistsAndEmpty(string path)
+ {
+ if (Directory.Exists(path) && !Directory.EnumerateFiles(path).Any())
+ {
+ Directory.Delete(path);
+ }
+ }
+
private Task InitAsync()
{
if (!Directory.Exists(_baseDirPath))
@@ -67,5 +85,25 @@ private Task InitAsync()
return Task.FromResult(0);
}
+
+ public Task GetSendFileUploadUrlAsync(Send send, string fileId)
+ => Task.FromResult($"/sends/{send.Id}/file/{fileId}");
+
+ public Task ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
+ {
+ var path = FilePath(send, fileId);
+ if (!File.Exists(path))
+ {
+ return Task.FromResult(false);
+ }
+
+ var fileInfo = new FileInfo(path);
+ if (expectedFileSize < fileInfo.Length - leeway || expectedFileSize > fileInfo.Length + leeway)
+ {
+ return Task.FromResult(false);
+ }
+
+ return Task.FromResult(true);
+ }
}
}
diff --git a/src/Core/Services/Implementations/SendService.cs b/src/Core/Services/Implementations/SendService.cs
index 8342c8bf7024..cb0fa4cf2d8b 100644
--- a/src/Core/Services/Implementations/SendService.cs
+++ b/src/Core/Services/Implementations/SendService.cs
@@ -28,6 +28,7 @@ public class SendService : ISendService
private readonly IReferenceEventService _referenceEventService;
private readonly GlobalSettings _globalSettings;
private readonly ICurrentContext _currentContext;
+ private const long _fileSizeLeeway = 1024L * 1024L; // 1MB
public SendService(
ISendRepository sendRepository,
@@ -74,101 +75,107 @@ public async Task SaveSendAsync(Send send)
}
}
- public async Task CreateSendAsync(Send send, SendFileData data, Stream stream, long requestLength)
+ public async Task SaveFileSendAsync(Send send, SendFileData data, long fileLength)
{
if (send.Type != SendType.File)
{
throw new BadRequestException("Send is not of type \"file\".");
}
- if (requestLength < 1)
+ if (fileLength < 1)
{
throw new BadRequestException("No file data.");
}
- var storageBytesRemaining = 0L;
- if (send.UserId.HasValue)
- {
- var user = await _userRepository.GetByIdAsync(send.UserId.Value);
- if (!(await _userService.CanAccessPremium(user)))
- {
- throw new BadRequestException("You must have premium status to use file sends.");
- }
-
- if (user.Premium)
- {
- storageBytesRemaining = user.StorageBytesRemaining();
- }
- else
- {
- // Users that get access to file storage/premium from their organization get the default
- // 1 GB max storage.
- storageBytesRemaining = user.StorageBytesRemaining(
- _globalSettings.SelfHosted ? (short)10240 : (short)1);
- }
- }
- else if (send.OrganizationId.HasValue)
- {
- var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
- if (!org.MaxStorageGb.HasValue)
- {
- throw new BadRequestException("This organization cannot use file sends.");
- }
-
- storageBytesRemaining = org.StorageBytesRemaining();
- }
+ var storageBytesRemaining = await StorageRemainingForSendAsync(send);
- if (storageBytesRemaining < requestLength)
+ if (storageBytesRemaining < fileLength)
{
throw new BadRequestException("Not enough storage available.");
}
var fileId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false);
- await _sendFileStorageService.UploadNewFileAsync(stream, send, fileId);
try
{
data.Id = fileId;
- data.Size = stream.Length;
+ data.Size = fileLength;
send.Data = JsonConvert.SerializeObject(data,
new JsonSerializerSettings { NullValueHandling = NullValueHandling.Ignore });
await SaveSendAsync(send);
+ return await _sendFileStorageService.GetSendFileUploadUrlAsync(send, fileId);
}
catch
{
// Clean up since this is not transactional
- await _sendFileStorageService.DeleteFileAsync(fileId);
+ await _sendFileStorageService.DeleteFileAsync(send, fileId);
throw;
}
}
+ public async Task UploadFileToExistingSendAsync(Stream stream, Send send)
+ {
+ if (send?.Data == null)
+ {
+ throw new BadRequestException("Send does not have file data");
+ }
+
+ if (send.Type != SendType.File)
+ {
+ throw new BadRequestException("Not a File Type Send.");
+ }
+
+ var data = JsonConvert.DeserializeObject(send.Data);
+
+ await _sendFileStorageService.UploadNewFileAsync(stream, send, data.Id);
+
+ if (!await ValidateSendFile(send))
+ {
+ throw new BadRequestException("File received does not match expected file length.");
+ }
+ }
+
+ public async Task ValidateSendFile(Send send)
+ {
+ var fileData = JsonConvert.DeserializeObject(send.Data);
+
+ var valid = await _sendFileStorageService.ValidateFileAsync(send, fileData.Id, fileData.Size, _fileSizeLeeway);
+
+ if (!valid)
+ {
+ // File reported differs in size from that promised. Must be a rogue client. Delete Send
+ await DeleteSendAsync(send);
+ }
+
+ return valid;
+ }
+
public async Task DeleteSendAsync(Send send)
{
await _sendRepository.DeleteAsync(send);
if (send.Type == Enums.SendType.File)
{
var data = JsonConvert.DeserializeObject(send.Data);
- await _sendFileStorageService.DeleteFileAsync(data.Id);
+ await _sendFileStorageService.DeleteFileAsync(send, data.Id);
}
await _pushService.PushSyncSendDeleteAsync(send);
}
- // Response: Send, password required, password invalid
- public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
+ public (bool grant, bool passwordRequiredError, bool passwordInvalidError) SendCanBeAccessed(Send send,
+ string password)
{
- var send = await _sendRepository.GetByIdAsync(sendId);
var now = DateTime.UtcNow;
if (send == null || send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount ||
send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || send.Disabled ||
send.DeletionDate < now)
{
- return (null, false, false);
+ return (false, false, false);
}
if (!string.IsNullOrWhiteSpace(send.Password))
{
if (string.IsNullOrWhiteSpace(password))
{
- return (null, true, false);
+ return (false, true, false);
}
var passwordResult = _passwordHasher.VerifyHashedPassword(new User(), send.Password, password);
if (passwordResult == PasswordVerificationResult.SuccessRehashNeeded)
@@ -177,11 +184,51 @@ public async Task DeleteSendAsync(Send send)
}
if (passwordResult == PasswordVerificationResult.Failed)
{
- return (null, false, true);
+ return (false, false, true);
}
}
- // TODO: maybe move this to a simple ++ sproc?
+
+ return (true, false, false);
+ }
+
+ // Response: Send, password required, password invalid
+ public async Task<(string, bool, bool)> GetSendFileDownloadUrlAsync(Send send, string fileId, string password)
+ {
+ if (send.Type != SendType.File)
+ {
+ throw new BadRequestException("Can only get a download URL for a file type of Send");
+ }
+
+ var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
+
+ if (!grantAccess)
+ {
+ return (null, passwordRequired, passwordInvalid);
+ }
+
send.AccessCount++;
+ await _sendRepository.ReplaceAsync(send);
+ return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), false, false);
+ }
+
+ // Response: Send, password required, password invalid
+ public async Task<(Send, bool, bool)> AccessAsync(Guid sendId, string password)
+ {
+ var send = await _sendRepository.GetByIdAsync(sendId);
+ var (grantAccess, passwordRequired, passwordInvalid) = SendCanBeAccessed(send, password);
+
+ if (!grantAccess)
+ {
+ return (null, passwordRequired, passwordInvalid);
+ }
+
+ // TODO: maybe move this to a simple ++ sproc?
+ if (send.Type != SendType.File)
+ {
+ // File sends are incremented during file download
+ send.AccessCount++;
+ }
+
await _sendRepository.ReplaceAsync(send);
await RaiseReferenceEventAsync(send, ReferenceEventType.SendAccessed);
return (send, false, false);
@@ -227,5 +274,42 @@ private async Task ValidateUserCanSaveAsync(Guid? userId)
}
}
}
+
+ private async Task StorageRemainingForSendAsync(Send send)
+ {
+ var storageBytesRemaining = 0L;
+ if (send.UserId.HasValue)
+ {
+ var user = await _userRepository.GetByIdAsync(send.UserId.Value);
+ if (!await _userService.CanAccessPremium(user))
+ {
+ throw new BadRequestException("You must have premium status to use file sends.");
+ }
+
+ if (user.Premium)
+ {
+ storageBytesRemaining = user.StorageBytesRemaining();
+ }
+ else
+ {
+ // Users that get access to file storage/premium from their organization get the default
+ // 1 GB max storage.
+ storageBytesRemaining = user.StorageBytesRemaining(
+ _globalSettings.SelfHosted ? (short)10240 : (short)1);
+ }
+ }
+ else if (send.OrganizationId.HasValue)
+ {
+ var org = await _organizationRepository.GetByIdAsync(send.OrganizationId.Value);
+ if (!org.MaxStorageGb.HasValue)
+ {
+ throw new BadRequestException("This organization cannot use file sends.");
+ }
+
+ storageBytesRemaining = org.StorageBytesRemaining();
+ }
+
+ return storageBytesRemaining;
+ }
}
}
diff --git a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs
index 2b174ce36fbb..b0d694dcdb68 100644
--- a/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs
+++ b/src/Core/Services/NoopImplementations/NoopSendFileStorageService.cs
@@ -2,17 +2,20 @@
using System.IO;
using System;
using Bit.Core.Models.Table;
+using Bit.Core.Enums;
namespace Bit.Core.Services
{
public class NoopSendFileStorageService : ISendFileStorageService
{
+ public FileUploadType FileUploadType => FileUploadType.Direct;
+
public Task UploadNewFileAsync(Stream stream, Send send, string attachmentId)
{
return Task.FromResult(0);
}
- public Task DeleteFileAsync(string fileId)
+ public Task DeleteFileAsync(Send send, string fileId)
{
return Task.FromResult(0);
}
@@ -27,9 +30,19 @@ public Task DeleteFilesForUserAsync(Guid userId)
return Task.FromResult(0);
}
- public Task GetSendFileDownloadUrlAsync(string fileId)
+ public Task GetSendFileDownloadUrlAsync(Send send, string fileId)
+ {
+ return Task.FromResult((string)null);
+ }
+
+ public Task GetSendFileUploadUrlAsync(Send send, string fileId)
{
return Task.FromResult((string)null);
}
+
+ public Task ValidateFileAsync(Send send, string fileId, long expectedFileSize, long leeway)
+ {
+ return Task.FromResult(false);
+ }
}
}