Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Store authentication info in keyvault #3127

Merged
merged 26 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/ApiService/ApiService/Functions/ReproVmss.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,17 @@ public class ReproVmss {
if (vm == null) {
return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "no such VM"), $"{request.OkV.VmId}");
}
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(vm.Auth);

if (auth == null) {
return await _context.RequestHandling.NotOk(req, Error.Create(ErrorCode.INVALID_REQUEST, "no auth info for the VM"), $"{request.OkV.VmId}");
}
var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(vm);
await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm, auth));
return response;
}

var vms = _context.ReproOperations.SearchStates(VmStateHelper.Available).Select(vm => vm with { Auth = null });
var vms = _context.ReproOperations.SearchStates(VmStateHelper.Available);
var response2 = req.CreateResponse(HttpStatusCode.OK);
await response2.WriteAsJsonAsync(vms);
return response2;
Expand Down Expand Up @@ -83,7 +87,15 @@ public class ReproVmss {
"repro_vm create");
}

// we’d like to track the usage of this feature;
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(vm.OkV.Auth);
if (auth is null) {
return await _context.RequestHandling.NotOk(
req,
Error.Create(ErrorCode.INVALID_REQUEST, "unable to find auth"),
"repro_vm create");
}

// we’d like to track the usage of this feature;
// anonymize the user ID so we can distinguish multiple requests
{
var data = userInfo.OkV.UserInfo.ToString(); // rely on record ToString
Expand All @@ -92,7 +104,8 @@ public class ReproVmss {
}

var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(vm.OkV);

await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm.OkV, auth));
return response;
}

Expand Down Expand Up @@ -127,8 +140,11 @@ public class ReproVmss {
_log.WithHttpStatus(r.ErrorV).Error($"Failed to replace repro {updatedRepro.VmId:Tag:VmId}");
}

if (vm.Auth != null) {
await _context.SecretsOperations.DeleteSecret(vm.Auth);
}
var response = req.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(updatedRepro);
await response.WriteAsJsonAsync(ReproVmResponse.FromRepro(vm, new Authentication("", "", "")));
return response;
}
}
18 changes: 13 additions & 5 deletions src/ApiService/ApiService/Functions/Scaleset.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public class Scaleset {
ScalesetId: Service.Scaleset.GenerateNewScalesetId(create.PoolName),
State: ScalesetState.Init,
NeedsConfigUpdate: false,
Auth: await Auth.BuildAuth(_log),
Auth: new SecretValue<Authentication>(await Auth.BuildAuth(_log)),
PoolName: create.PoolName,
VmSku: create.VmSku,
Image: image,
Expand Down Expand Up @@ -161,7 +161,8 @@ public class Scaleset {
}

// auth not included on create results, only GET with include_auth set
var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: false);

var response = ScalesetResponse.ForScaleset(scaleset, null);
return await RequestHandling.Ok(req, response);
}

Expand Down Expand Up @@ -195,7 +196,7 @@ public class Scaleset {
scaleset = await _context.ScalesetOperations.SetSize(scaleset, size);
}

var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: false);
var response = ScalesetResponse.ForScaleset(scaleset, null);
return await RequestHandling.Ok(req, response);
}

Expand All @@ -214,15 +215,22 @@ public class Scaleset {

var scaleset = scalesetResult.OkV;

var response = ScalesetResponse.ForScaleset(scaleset, includeAuth: search.IncludeAuth);
Authentication? auth;
auth = scaleset.Auth == null
? null
: search.IncludeAuth
? await _context.SecretsOperations.GetSecretValue<Authentication>(scaleset.Auth)
: null;

var response = ScalesetResponse.ForScaleset(scaleset, auth);
response = response with { Nodes = await _context.ScalesetOperations.GetNodes(scaleset) };
return await RequestHandling.Ok(req, response);
}

var states = search.State ?? Enumerable.Empty<ScalesetState>();
var scalesets = await _context.ScalesetOperations.SearchStates(states).ToListAsync();
// don't return auths during list actions, only 'get'
var result = scalesets.Select(ss => ScalesetResponse.ForScaleset(ss, includeAuth: false));
var result = scalesets.Select(ss => ScalesetResponse.ForScaleset(ss));
return await RequestHandling.Ok(req, result);
}
}
4 changes: 3 additions & 1 deletion src/ApiService/ApiService/Functions/Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ public class Tasks {
_context.NodeTasksOperations.GetNodeAssignments(taskId).ToListAsync().AsTask(),
_context.TaskEventOperations.GetSummary(taskId).ToListAsync().AsTask());

var auth = task.Auth == null ? null : await _context.SecretsOperations.GetSecretValue(task.Auth);

var result = new TaskSearchResult(
JobId: task.JobId,
TaskId: task.TaskId,
State: task.State,
Os: task.Os,
Config: task.Config,
Error: task.Error,
Auth: task.Auth,
Auth: auth,
Heartbeat: task.Heartbeat,
EndTime: task.EndTime,
UserInfo: task.UserInfo,
Expand Down
45 changes: 37 additions & 8 deletions src/ApiService/ApiService/OneFuzzTypes/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public record Proxy
[RowKey] Guid ProxyId,
DateTimeOffset? CreatedTimestamp,
VmState State,
Authentication Auth,
ISecret<Authentication> Auth,
string? Ip,
Error? Error,
string Version,
Expand Down Expand Up @@ -282,7 +282,7 @@ public record Task(
Os Os,
TaskConfig Config,
Error? Error = null,
Authentication? Auth = null,
ISecret<Authentication>? Auth = null,
DateTimeOffset? Heartbeat = null,
DateTimeOffset? EndTime = null,
UserInfo? UserInfo = null) : StatefulEntityBase<TaskState>(State) {
Expand Down Expand Up @@ -422,7 +422,7 @@ public partial record Scaleset(
bool EphemeralOsDisks,
bool NeedsConfigUpdate,
Dictionary<string, string> Tags,
Authentication? Auth = null,
ISecret<Authentication>? Auth = null,
Error? Error = null,
Guid? ClientId = null,
Guid? ClientObjectId = null
Expand Down Expand Up @@ -718,7 +718,7 @@ public record Repro(
[PartitionKey][RowKey] Guid VmId,
Guid TaskId,
ReproConfig Config,
Authentication? Auth,
ISecret<Authentication> Auth,
Os Os,
VmState State = VmState.Init,
Error? Error = null,
Expand Down Expand Up @@ -788,15 +788,23 @@ public record Vm(
Region Region,
string Sku,
ImageReference Image,
Authentication Auth,
ISecret<Authentication> Auth,
Nsg? Nsg,
IDictionary<string, string>? Tags
) {
public string Name { get; } = Name.Length > 40 ? throw new ArgumentOutOfRangeException("VM name too long") : Name;
};


public interface ISecret {
[JsonIgnore]
bool IsHIddden { get; }
[JsonIgnore]
Uri? Uri { get; }
string? GetValue();
}
[JsonConverter(typeof(ISecretConverterFactory))]
public interface ISecret<T> { }
public interface ISecret<T> : ISecret { }

public class ISecretConverterFactory : JsonConverterFactory {
public override bool CanConvert(Type typeToConvert) {
Expand Down Expand Up @@ -841,9 +849,30 @@ public class ISecretConverter<T> : JsonConverter<ISecret<T>> {



public record SecretValue<T>(T Value) : ISecret<T>;
public record SecretValue<T>(T Value) : ISecret<T> {
[JsonIgnore]
public bool IsHIddden => false;
[JsonIgnore]
public Uri? Uri => null;

public string? GetValue() {
if (Value is string secretString) {
return secretString.Trim();
}

return JsonSerializer.Serialize(Value, EntityConverter.GetJsonSerializerOptions());
}
}

public record SecretAddress<T>(Uri Url) : ISecret<T> {
[JsonIgnore]
public Uri? Uri => Url;
[JsonIgnore]
public bool IsHIddden => true;
public string? GetValue() => null;

public record SecretAddress<T>(Uri Url) : ISecret<T>;

}

public record SecretData<T>(ISecret<T> Secret) {
}
Expand Down
34 changes: 32 additions & 2 deletions src/ApiService/ApiService/OneFuzzTypes/Responses.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ public record ScalesetResponse(
Dictionary<string, string> Tags,
List<ScalesetNodeState>? Nodes
) : BaseResponse() {
public static ScalesetResponse ForScaleset(Scaleset s, bool includeAuth)
public static ScalesetResponse ForScaleset(Scaleset s, Authentication? auth = null)
=> new(
PoolName: s.PoolName,
ScalesetId: s.ScalesetId,
State: s.State,
Auth: includeAuth ? s.Auth : null,
Auth: auth,
VmSku: s.VmSku,
Image: s.Image,
Region: s.Region,
Expand Down Expand Up @@ -220,3 +220,33 @@ public record NotificationTestResponse(
bool Success,
string? Error = null
) : BaseResponse();


public record ReproVmResponse(
Guid VmId,
Guid TaskId,
ReproConfig Config,
Authentication? Auth,
Os Os,
VmState State = VmState.Init,
Error? Error = null,
string? Ip = null,
DateTimeOffset? EndTime = null,
UserInfo? UserInfo = null
) : BaseResponse() {

public static ReproVmResponse FromRepro(Repro repro, Authentication? auth) {
return new ReproVmResponse(
VmId: repro.VmId,
TaskId: repro.TaskId,
Config: repro.Config,
Auth: auth,
Os: repro.Os,
State: repro.State,
Error: repro.Error,
Ip: repro.Ip,
EndTime: repro.EndTime,
UserInfo: repro.UserInfo
);
}
}
16 changes: 4 additions & 12 deletions src/ApiService/ApiService/TestHooks/TestHooks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,6 @@ public class TestHooks {
}


[Function("GetKeyvaultAddress")]
public async Task<HttpResponseData> GetKeyVaultAddress([HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "testhooks/secrets/keyvaultaddress")] HttpRequestData req) {
_log.Info($"Getting keyvault address");
var addr = _secretOps.GetKeyvaultAddress();
var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(addr);
return resp;
}

[Function("SaveToKeyvault")]
public async Task<HttpResponseData> SaveToKeyvault([HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "testhooks/secrets/keyvault")] HttpRequestData req) {
Expand All @@ -60,10 +52,10 @@ public class TestHooks {
return req.CreateResponse(HttpStatusCode.BadRequest);
} else {
_log.Info($"Saving secret data in the keyvault");
var r = await _secretOps.SaveToKeyvault(secretData);
var addr = _secretOps.GetKeyvaultAddress();
var r = await _secretOps.StoreSecretData(secretData);

var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(addr);
await resp.WriteAsJsonAsync((r.Secret as SecretAddress<string>)?.Url);
return resp;
}
}
Expand All @@ -79,7 +71,7 @@ public class TestHooks {
select new KeyValuePair<string, string>(Uri.UnescapeDataString(cs.Substring(0, i)), Uri.UnescapeDataString(cs.Substring(i + 1)));

var qs = new Dictionary<string, string>(q);
var d = await _secretOps.GetSecretStringValue(new SecretData<string>(new SecretValue<string>(qs["SecretName"])));
var d = await _secretOps.GetSecretValue(new SecretValue<string>(qs["SecretName"]));

var resp = req.CreateResponse(HttpStatusCode.OK);
await resp.WriteAsJsonAsync(d);
Expand Down
6 changes: 5 additions & 1 deletion src/ApiService/ApiService/onefuzzlib/Extension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ private sealed class Settings {
var sep = pool.Os == Os.Windows ? "\r\n" : "\n";

if (pool.Os == Os.Windows && scaleSet.Auth is not null) {
var sshKey = scaleSet.Auth.PublicKey.Trim();
var auth = await _context.SecretsOperations.GetSecretValue<Authentication>(scaleSet.Auth);
if (auth is null) {
throw new Exception($"unable to retrieve auth: {scaleSet.Auth}");
}
var sshKey = auth.PublicKey.Trim();
var sshPath = "$env:ProgramData/ssh/administrators_authorized_keys";
commands.Add($"Set-Content -Path {sshPath} -Value \"{sshKey}\"");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ public NotificationOperations(ILogTracer log, IOnefuzzContext context)

switch (notificationTemplate) {
case AdoTemplate adoTemplate:
var hiddenAuthToken = await _context.SecretsOperations.SaveToKeyvault(adoTemplate.AuthToken);
var hiddenAuthToken = await _context.SecretsOperations.StoreSecretData(adoTemplate.AuthToken);
return adoTemplate with { AuthToken = hiddenAuthToken };
case GithubIssuesTemplate githubIssuesTemplate:
var hiddenAuth = await _context.SecretsOperations.SaveToKeyvault(githubIssuesTemplate.Auth);
var hiddenAuth = await _context.SecretsOperations.StoreSecretData(githubIssuesTemplate.Auth);
return githubIssuesTemplate with { Auth = hiddenAuth };
case TeamsTemplate teamsTemplate:
var hiddenUrl = await _context.SecretsOperations.SaveToKeyvault(teamsTemplate.Url);
var hiddenUrl = await _context.SecretsOperations.StoreSecretData(teamsTemplate.Url);
return teamsTemplate with { Url = hiddenUrl };
default:
throw new ArgumentOutOfRangeException(nameof(notificationTemplate));
Expand Down
2 changes: 1 addition & 1 deletion src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public ProxyOperations(ILogTracer log, IOnefuzzContext context)
}

_logTracer.Info($"creating proxy: region:{region:Tag:Region}");
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, await Auth.BuildAuth(_logTracer), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false);
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, new SecretValue<Authentication>(await Auth.BuildAuth(_logTracer)), null, null, _context.ServiceConfiguration.OneFuzzVersion, null, false);

var r = await Replace(newProxy);
if (!r.IsOk) {
Expand Down
16 changes: 10 additions & 6 deletions src/ApiService/ApiService/onefuzzlib/ReproOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ public ReproOperations(ILogTracer log, IOnefuzzContext context)
);
}

if (repro.Auth == null) {
throw new Exception("missing auth");
}

return new Vm(
repro.VmId.ToString(),
vmConfig.Region,
Expand Down Expand Up @@ -260,12 +256,18 @@ await _context.ReproOperations.GetSetupContainer(repro)
}

var files = new Dictionary<string, string>();
var auth = await _context.SecretsOperations.GetSecretValue(repro.Auth);

if (auth == null) {
return OneFuzzResultVoid.Error(ErrorCode.VM_CREATE_FAILED, "unable to fetch auth secret");
}

switch (task.Os) {
case Os.Windows:
var sshPath = "$env:ProgramData/ssh/administrators_authorized_keys";
var cmds = new List<string>()
{
$"Set-Content -Path {sshPath} -Value \"{repro.Auth.PublicKey}\"",
$"Set-Content -Path {sshPath} -Value \"{auth.PublicKey}\"",
". C:\\onefuzz\\tools\\win64\\onefuzz.ps1",
"Set-SetSSHACL",
$"while (1) {{ cdb -server tcp:port=1337 -c \"g\" setup\\{task.Config.Task.TargetExe} {report?.InputBlob?.Name} }}"
Expand Down Expand Up @@ -333,12 +335,14 @@ await _context.ReproOperations.GetSetupContainer(repro)
return OneFuzzResult<Repro>.Error(ErrorCode.INVALID_REQUEST, "unable to find task");
}

var auth = await _context.SecretsOperations.StoreSecret(new SecretValue<Authentication>(await Auth.BuildAuth(_logTracer)));

var vm = new Repro(
VmId: Guid.NewGuid(),
Config: config,
TaskId: task.TaskId,
Os: task.Os,
Auth: await Auth.BuildAuth(_logTracer),
Auth: new SecretAddress<Authentication>(auth),
EndTime: DateTimeOffset.UtcNow + TimeSpan.FromHours(config.Duration),
UserInfo: userInfo);

Expand Down
Loading
Loading